├── experiments ├── __init__.py ├── nyuv2 │ ├── __init__.py │ ├── README.md │ ├── utils.py │ ├── data.py │ ├── trainer.py │ └── models.py └── utils.py ├── GO4Align.png ├── methods ├── __init__.py ├── sdp_kmeans │ ├── __init__.py │ ├── utils.py │ ├── embedding.py │ ├── nmf.py │ └── sdp.py ├── min_norm_solvers.py ├── cluster_methods.py └── weight_methods.py ├── requirements.txt ├── setup.py └── README.md /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/nyuv2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /GO4Align.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autumn9999/GO4Align/HEAD/GO4Align.png -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from methods.weight_methods import ( 2 | METHODS, 3 | MGDA, 4 | STL, 5 | LinearScalarization, 6 | NashMTL, 7 | PCGrad, 8 | Uncertainty, 9 | ) 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.2.1 2 | numpy>=1.18.2 3 | torch>=1.4.0 4 | torchvision>=0.8.0 5 | cvxpy 6 | tqdm>=4.45.0 7 | pandas 8 | scikit-learn 9 | seaborn 10 | wandb==0.12.21 11 | plotly -------------------------------------------------------------------------------- /methods/sdp_kmeans/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | from .embedding import sdp_kmeans_embedding, spectral_embedding 3 | from .nmf 
def connected_components(sym_mat, thresh=1e-4):
    """Split a symmetric affinity matrix into its connected components.

    The matrix is binarized at ``thresh`` times its largest entry and the
    connected components of the resulting undirected graph are extracted.

    :param sym_mat: symmetric (n, n) affinity matrix
    :param thresh: relative cutoff applied to ``sym_mat.max()``
    :return: list of boolean masks, one per component
    """
    cutoff = sym_mat.max() * thresh
    adjacency = sym_mat > cutoff
    n_comp, labels = inner_conn_comp(adjacency, directed=False,
                                     return_labels=True)
    return [labels == comp_id for comp_id in range(n_comp)]
def requirements():
    """Read requirements.txt and return its entries as a list of strings.

    Blank lines and comment lines (starting with ``#``) are skipped so the
    result can be passed directly to setuptools' ``install_requires`` —
    the original version included empty lines, which are not valid
    requirement specifiers.
    """
    list_requirements = []
    with open("requirements.txt") as f:
        for line in f:
            entry = line.strip()
            # Skip blanks and comments; keep everything else verbatim.
            if entry and not entry.startswith("#"):
                list_requirements.append(entry)
    return list_requirements
21 | first = target_dim 22 | last = None 23 | if not gramian: 24 | mat = mat - mat.mean(axis=0) 25 | mat = mat.dot(mat.T) 26 | eigvals, eigvecs = eigh(mat) 27 | 28 | sl = slice(-first, last) 29 | eigvecs = eigvecs[:, sl] 30 | eigvals_crop = eigvals[sl] 31 | Y = eigvecs.dot(np.diag(np.sqrt(eigvals_crop))) 32 | Y = Y[:, ::-1] 33 | 34 | variance_explaned(eigvals, eigvals_crop) 35 | return Y 36 | 37 | 38 | def variance_explaned(eigvals, eigvals_crop): 39 | eigvals_crop[eigvals_crop < 0] = 0 40 | eigvals[eigvals < 0] = 0 41 | var = np.sum(eigvals_crop) / np.sum(eigvals) 42 | print('Variance explained:', var) 43 | -------------------------------------------------------------------------------- /methods/sdp_kmeans/nmf.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | import numpy as np 3 | 4 | 5 | def symnmf_admm(A, k, H=None, maxiter=1e3, tol=1e-5, sigma=1): 6 | """ 7 | A is a symmetric matrix 8 | Solves || A - W.dot(W.T) ||_F^2 s.t. 
def symnmf_gram_admm(A, k, H=None, maxiter=1e3, tol=1e-5, sigma=1):
    """Symmetric NMF of a Gram matrix via ADMM.

    Solves || A.dot(A.T) - W.dot(W.T) ||_F^2  s.t.  W >= 0, where the Gram
    matrix A.dot(A.T) is never formed explicitly.

    Parameters
    ----------
    A : (n, m) array; its Gram matrix ``A @ A.T`` is factored.
    k : target rank of the factorization.
    H : optional (n, k) initial factor; random nonnegative if omitted.
    maxiter : maximum number of ADMM iterations.
    tol : relative primal-residual tolerance used as stopping criterion.
    sigma : ADMM penalty parameter.

    Returns
    -------
    W : (n, k) nonnegative factor.
    """
    A = A.copy()
    A[A < 0] = 0

    # Bug fix: `n` is needed for Gamma below even when the caller supplies H;
    # the original only defined it inside the `H is None` branch, raising
    # NameError otherwise.
    n = A.shape[0]
    if H is None:
        H = np.sqrt(A.mean() / k) * np.random.randn(n, k)
        np.abs(H, H)

    Gamma = np.zeros((n, k))
    id_k = np.identity(k)
    step = 1

    error = []
    for i in range(int(maxiter)):
        temp = np.linalg.inv(H.T.dot(H) + sigma * id_k)
        W = (A.dot(A.T.dot(H)) + sigma * H - Gamma).dot(temp)
        W = np.maximum(W, 0)
        temp = np.linalg.inv(W.T.dot(W) + sigma * id_k)
        H = (A.dot(A.T.dot(W)) + sigma * W + Gamma).dot(temp)
        H = np.maximum(H, 0)
        Gamma += step * sigma * (W - H)

        # Relative disagreement between the two split variables.
        error.append(np.linalg.norm(W - H) / np.linalg.norm(W))
        if i > 0 and np.abs(error[-1]) < tol:
            break

    return W
def depth_error(x_pred, x_output):
    """Mean absolute and relative depth errors over valid pixels.

    A pixel is valid when its ground-truth values do not sum to zero across
    the channel dimension. Returns ``(abs_err, rel_err)`` as Python floats.
    """
    device = x_pred.device
    # Valid-pixel mask: ground-truth channel sum is non-zero.
    valid = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).to(device)
    pred_vals = x_pred.masked_select(valid)
    true_vals = x_output.masked_select(valid)
    n_valid = torch.nonzero(valid, as_tuple=False).size(0)
    abs_diff = torch.abs(pred_vals - true_vals)
    rel_diff = abs_diff / true_vals
    return (
        (torch.sum(abs_diff) / n_valid).item(),
        (torch.sum(rel_diff) / n_valid).item(),
    )
0.3014, 0.5720, 0.6915] 76 | ) # base results from CAGrad 77 | SIGN = np.array([1, 1, 0, 0, 0, 0, 1, 1, 1]) 78 | KK = np.ones(9) * -1 79 | 80 | 81 | def delta_fn(a): 82 | return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0 # * 100 for percentage 83 | -------------------------------------------------------------------------------- /experiments/nyuv2/data.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data.dataset import Dataset 9 | 10 | """Source: https://github.com/Cranial-XIX/CAGrad/blob/main/nyuv2/create_dataset.py 11 | 12 | """ 13 | 14 | 15 | class RandomScaleCrop(object): 16 | """ 17 | Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34. 18 | """ 19 | 20 | def __init__(self, scale=[1.0, 1.2, 1.5]): 21 | self.scale = scale 22 | 23 | def __call__(self, img, label, depth, normal): 24 | height, width = img.shape[-2:] 25 | sc = self.scale[random.randint(0, len(self.scale) - 1)] 26 | h, w = int(height / sc), int(width / sc) 27 | i = random.randint(0, height - h) 28 | j = random.randint(0, width - w) 29 | img_ = F.interpolate( 30 | img[None, :, i : i + h, j : j + w], 31 | size=(height, width), 32 | mode="bilinear", 33 | align_corners=True, 34 | ).squeeze(0) 35 | label_ = ( 36 | F.interpolate( 37 | label[None, None, i : i + h, j : j + w], 38 | size=(height, width), 39 | mode="nearest", 40 | ) 41 | .squeeze(0) 42 | .squeeze(0) 43 | ) 44 | depth_ = F.interpolate( 45 | depth[None, :, i : i + h, j : j + w], size=(height, width), mode="nearest" 46 | ).squeeze(0) 47 | normal_ = F.interpolate( 48 | normal[None, :, i : i + h, j : j + w], 49 | size=(height, width), 50 | mode="bilinear", 51 | align_corners=True, 52 | ).squeeze(0) 53 | return img_, label_, depth_ / sc, normal_ 54 | 55 | 56 | class NYUv2(Dataset): 57 | """ 58 | We could further improve the performance with 
the data augmentation of NYUv2 defined in: 59 | [1] PAD-Net: Multi-Tasks Guided Prediction-and-Distillation Network for Simultaneous Depth Estimation and Scene Parsing 60 | [2] Pattern affinitive propagation across depth, surface normal and semantic segmentation 61 | [3] Mti-net: Multiscale task interaction networks for multi-task learning 62 | 1. Random scale in a selected raio 1.0, 1.2, and 1.5. 63 | 2. Random horizontal flip. 64 | Please note that: all baselines and MTAN did NOT apply data augmentation in the original paper. 65 | """ 66 | 67 | def __init__(self, root, train=True, augmentation=False): 68 | self.train = train 69 | self.root = os.path.expanduser(root) 70 | self.augmentation = augmentation 71 | 72 | # read the data file 73 | if train: 74 | self.data_path = root + "/train" 75 | else: 76 | self.data_path = root + "/val" 77 | 78 | # calculate data length 79 | self.data_len = len( 80 | fnmatch.filter(os.listdir(self.data_path + "/image"), "*.npy") 81 | ) 82 | 83 | def __getitem__(self, index): 84 | # load data from the pre-processed npy files 85 | image = torch.from_numpy( 86 | np.moveaxis( 87 | np.load(self.data_path + "/image/{:d}.npy".format(index)), -1, 0 88 | ) 89 | ) 90 | semantic = torch.from_numpy( 91 | np.load(self.data_path + "/label/{:d}.npy".format(index)) 92 | ) 93 | depth = torch.from_numpy( 94 | np.moveaxis( 95 | np.load(self.data_path + "/depth/{:d}.npy".format(index)), -1, 0 96 | ) 97 | ) 98 | normal = torch.from_numpy( 99 | np.moveaxis( 100 | np.load(self.data_path + "/normal/{:d}.npy".format(index)), -1, 0 101 | ) 102 | ) 103 | 104 | # apply data augmentation if required 105 | if self.augmentation: 106 | image, semantic, depth, normal = RandomScaleCrop()( 107 | image, semantic, depth, normal 108 | ) 109 | if torch.rand(1) < 0.5: 110 | image = torch.flip(image, dims=[2]) 111 | semantic = torch.flip(semantic, dims=[1]) 112 | depth = torch.flip(depth, dims=[2]) 113 | normal = torch.flip(normal, dims=[2]) 114 | normal[0, :, :] = 
def str2bool(v):
    """Parse a command-line flag value into a bool.

    Booleans are returned as-is; otherwise the usual textual spellings
    ("yes"/"no", "true"/"false", "t"/"f", "y"/"n", "1"/"0") are accepted,
    case-insensitively. Anything else raises ``argparse.ArgumentTypeError``.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def set_seed(seed):
    """Seed every RNG used in the experiments for reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and, when available,
    all CUDA devices), and configures cuDNN for deterministic behavior,
    trading speed for reproducibility.

    :param seed: integer seed value
    :return: None
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
torch.cuda.is_available() and not no_cuda else "cpu" 119 | ) 120 | #------------------- 121 | 122 | def extract_weight_method_parameters_from_args(args): 123 | weight_methods_parameters = defaultdict(dict) 124 | weight_methods_parameters.update( 125 | dict( 126 | nashmtl=dict( update_weights_every=args.update_weights_every, optim_niter=args.nashmtl_optim_niter), 127 | stl=dict(main_task=args.main_task), 128 | cagrad=dict(c=args.c), 129 | dwa=dict(temp=args.dwa_temp), 130 | #------------- 131 | go4align=dict(num_groups=args.num_groups, robust_step_size=args.robust_step_size), 132 | group=dict(task_weights=args.task_weights, robust_step_size=args.robust_step_size), 133 | group_random=dict(task_weights=args.task_weights, robust_step_size=args.robust_step_size), 134 | #-------------rebuttal 135 | group_sklearn_spectral_clustering_cluster_qr=dict(num_groups=args.num_groups, robust_step_size=args.robust_step_size), 136 | group_sklearn_spectral_clustering_discretize=dict(num_groups=args.num_groups, robust_step_size=args.robust_step_size), 137 | group_sklearn_spectral_clustering_kmeans=dict(num_groups=args.num_groups, robust_step_size=args.robust_step_size), 138 | group_sdp_clustering=dict(num_groups=args.num_groups, robust_step_size=args.robust_step_size), 139 | ) 140 | ) 141 | return weight_methods_parameters 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GO4Align 2 | 3 | Welcome to the official repository for "GO4Align: Group Optimization for Multi-Task Alignment," one effective and efficient approach to multi-task optimization. 4 | 5 | 6 | ### Project Webpage 7 | 8 | Details will be available soon. 9 | 10 | ### Abstract 11 | 12 | This paper proposes GO4Align, a multi-task optimization approach that tackles task imbalance by explicitly aligning the optimization across tasks. 
13 | To achieve this, we design an adaptive group risk minimization strategy, compromising two crucial techniques in implementation: 14 | - **dynamical group assignment**, which clusters similar tasks based on task interactions; 15 | - **risk-guided group indicators**, which exploit consistent task correlations with risk information from previous iterations. 16 | 17 | Comprehensive experimental results on diverse typical benchmarks demonstrate our method's performance superiority with even lower computational costs. 18 | 19 | ### Paper 20 | 21 | [The preprint of our paper](https://arxiv.org/abs/2404.06486) is available on arXiv. 22 | 23 | ### Framework of Adaptive Group Risk Minimization 24 | 25 | 26 |

27 | 28 |

29 | 30 | 31 | --- 32 | 33 | ### Setup Environment 34 | 35 | We recommend using miniconda to create a virtual environment for running the code: 36 | ```bash 37 | conda create -n go4align python=3.9.7 38 | conda activate go4align 39 | conda install pytorch==1.13.1 torchvision==0.14.1 cudatoolkit=12.3 -c pytorch 40 | conda install pyg -c pyg -c conda-forge 41 | ``` 42 | 43 | Install the package by running the following commands in the terminal: 44 | ```bash 45 | git clone https://github.com/autumn9999/GO4Align.git 46 | cd GO4Align 47 | pip install -e . 48 | ``` 49 | 50 | GPU: NVIDIA A100-SXM4-40GB 51 | 52 | ### Download Datasets 53 | 54 | This work is evaluated on several multi-task learning benchmarks: 55 | 56 | 1. [NYUv2](https://www.dropbox.com/sh/86nssgwm6hm3vkb/AACrnUQ4GxpdrBbLjb6n-mWNa?dl=0) (3 tasks), where the link is provided by the previous MTO work [CAGrad](https://github.com/Cranial-XIX/CAGrad.git). 57 | 2. [CityScapes](https://www.dropbox.com/sh/gaw6vh6qusoyms6/AADwWi0Tp3E3M4B2xzeGlsEna?dl=0) (2 tasks), where the link is provided by [CAGrad](https://github.com/Cranial-XIX/CAGrad.git). 58 | 3. [CelebA](https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg) (40 tasks). Details can be found in the previous MTO work [FAMO](https://github.com/Cranial-XIX/FAMO). 59 | 4. QM9 (11 tasks), which can be downloaded automatically by `torch_geometric.datasets`. Details can be found in [FAMO](https://github.com/Cranial-XIX/FAMO/blob/main/experiments/quantum_chemistry/trainer.py). 60 | 61 | 62 | 63 | ### Run Experiments 64 | Here we provide experiments code for NYUv2. To run the experiment with other benchmark, please refer to unified APIs in [NashMTL](https://github.com/AvivNavon/nash-mtl) or [FAMO](https://github.com/Cranial-XIX/FAMO). 65 | 66 | ```bash 67 | cd experiment/nyuv2 68 | python trainer.py --method go4align 69 | ``` 70 | 71 | We also support the following MTL methods as alternatives. 
72 | 73 | | Method (code name) | Paper (notes) | 74 | |:-------------------------:|:-------------------------------------------------------------------------------------------------------------------------------:| 75 | | Gradient-oriented methods | --------------------------------------------------------------------------------- | 76 | | MGDA | [Multi-Task Learning as Multi-Objective Optimization](https://arxiv.org/pdf/1810.04650) | 77 | | PCGrad | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) | 78 | | CAGrad | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048.pdf) | 79 | | IMTL-G | [Towards Impartial Multi-task Learning](https://openreview.net/forum?id=IMPnRXEWpvr) | 80 | | NashMTL | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017v1.pdf) | 81 | | Loss-oriented methods | --------------------------------------------------------------------------------- | 82 | | LS | - (equal weighting) | 83 | | SI | - (see Nash-MTL paper for details) | 84 | | RLW | [A Closer Look at Loss Weighting in Multi-Task Learning](https://arxiv.org/pdf/2111.10603.pdf) | 85 | | DWA | [End-to-End Multi-Task Learning with Attention](https://arxiv.org/pdf/1803.10704) | 86 | | UW | [Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics](https://arxiv.org/pdf/1705.07115v3.pdf) | 87 | | FAMO | [FAMO: Fast Adaptive Multitask Optimization](https://arxiv.org/pdf/2306.03792) | 88 | 89 | 90 | 91 | ### Citation 92 | This repo is built upon [NashMTL](https://github.com/AvivNavon/nash-mtl) or [FAMO](https://github.com/Cranial-XIX/FAMO). 
If our work **GO4Align** is helpful in your research or projects, please cite the following papers: 93 | 94 | ```bib 95 | @article{shen2024go4align, 96 | title={GO4Align: Group Optimization for Multi-Task Alignment}, 97 | author={Shen, Jiayi and Wang, Cheems and Xiao, Zehao and Van Noord, Nanne and Worring, Marcel}, 98 | journal={arXiv preprint arXiv:2404.06486}, 99 | year={2024} 100 | } 101 | 102 | @article{liu2024famo, 103 | title={Famo: Fast adaptive multitask optimization}, 104 | author={Liu, Bo and Feng, Yihao and Stone, Peter and Liu, Qiang}, 105 | journal={Advances in Neural Information Processing Systems}, 106 | volume={36}, 107 | year={2024} 108 | } 109 | 110 | @article{navon2022multi, 111 | title={Multi-task learning as a bargaining game}, 112 | author={Navon, Aviv and Shamsian, Aviv and Achituve, Idan and Maron, Haggai and Kawaguchi, Kenji and Chechik, Gal and Fetaya, Ethan}, 113 | journal={arXiv preprint arXiv:2202.01017}, 114 | year={2022} 115 | } 116 | ``` 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /methods/min_norm_solvers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | # This code is from 6 | # Multi-Task Learning as Multi-Objective Optimization 7 | # Ozan Sener, Vladlen Koltun 8 | # Neural Information Processing Systems (NeurIPS) 2018 9 | # https://github.com/intel-isl/MultiObjectiveOptimization 10 | class MinNormSolver: 11 | MAX_ITER = 250 12 | STOP_CRIT = 1e-5 13 | 14 | @staticmethod 15 | def _min_norm_element_from2(v1v1, v1v2, v2v2): 16 | """ 17 | Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2 18 | d is the distance (objective) optimzed 19 | v1v1 = 20 | v1v2 = 21 | v2v2 = 22 | """ 23 | if v1v2 >= v1v1: 24 | # Case: Fig 1, third column 25 | gamma = 0.999 26 | cost = v1v1 27 | return gamma, cost 28 | if v1v2 >= v2v2: 29 | # Case: Fig 1, first column 30 | gamma = 0.001 31 | cost = v2v2 32 | 
return gamma, cost 33 | # Case: Fig 1, second column 34 | gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2)) 35 | cost = v2v2 + gamma * (v1v2 - v2v2) 36 | return gamma, cost 37 | 38 | @staticmethod 39 | def _min_norm_2d(vecs, dps): 40 | """ 41 | Find the minimum norm solution as combination of two points 42 | This is correct only in 2D 43 | ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j 44 | """ 45 | dmin = 1e8 46 | for i in range(len(vecs)): 47 | for j in range(i + 1, len(vecs)): 48 | if (i, j) not in dps: 49 | dps[(i, j)] = 0.0 50 | for k in range(len(vecs[i])): 51 | dps[(i, j)] += torch.dot( 52 | vecs[i][k], vecs[j][k] 53 | ).item() # torch.dot(vecs[i][k], vecs[j][k]).dataset[0] 54 | dps[(j, i)] = dps[(i, j)] 55 | if (i, i) not in dps: 56 | dps[(i, i)] = 0.0 57 | for k in range(len(vecs[i])): 58 | dps[(i, i)] += torch.dot( 59 | vecs[i][k], vecs[i][k] 60 | ).item() # torch.dot(vecs[i][k], vecs[i][k]).dataset[0] 61 | if (j, j) not in dps: 62 | dps[(j, j)] = 0.0 63 | for k in range(len(vecs[i])): 64 | dps[(j, j)] += torch.dot( 65 | vecs[j][k], vecs[j][k] 66 | ).item() # torch.dot(vecs[j][k], vecs[j][k]).dataset[0] 67 | c, d = MinNormSolver._min_norm_element_from2( 68 | dps[(i, i)], dps[(i, j)], dps[(j, j)] 69 | ) 70 | if d < dmin: 71 | dmin = d 72 | sol = [(i, j), c, d] 73 | return sol, dps 74 | 75 | @staticmethod 76 | def _projection2simplex(y): 77 | """ 78 | Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i 79 | """ 80 | m = len(y) 81 | sorted_y = np.flip(np.sort(y), axis=0) 82 | tmpsum = 0.0 83 | tmax_f = (np.sum(y) - 1.0) / m 84 | for i in range(m - 1): 85 | tmpsum += sorted_y[i] 86 | tmax = (tmpsum - 1) / (i + 1.0) 87 | if tmax > sorted_y[i + 1]: 88 | tmax_f = tmax 89 | break 90 | return np.maximum(y - tmax_f, np.zeros(y.shape)) 91 | 92 | @staticmethod 93 | def _next_point(cur_val, grad, n): 94 | proj_grad = grad - (np.sum(grad) / n) 95 | tm1 = -1.0 * cur_val[proj_grad < 
0] / proj_grad[proj_grad < 0] 96 | tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0]) 97 | 98 | skippers = np.sum(tm1 < 1e-7) + np.sum(tm2 < 1e-7) 99 | t = 1 100 | if len(tm1[tm1 > 1e-7]) > 0: 101 | t = np.min(tm1[tm1 > 1e-7]) 102 | if len(tm2[tm2 > 1e-7]) > 0: 103 | t = min(t, np.min(tm2[tm2 > 1e-7])) 104 | 105 | next_point = proj_grad * t + cur_val 106 | next_point = MinNormSolver._projection2simplex(next_point) 107 | return next_point 108 | 109 | @staticmethod 110 | def find_min_norm_element(vecs): 111 | """ 112 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull 113 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1. 114 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j}) 115 | Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence 116 | """ 117 | # Solution lying at the combination of two points 118 | dps = {} 119 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps) 120 | 121 | n = len(vecs) 122 | sol_vec = np.zeros(n) 123 | sol_vec[init_sol[0][0]] = init_sol[1] 124 | sol_vec[init_sol[0][1]] = 1 - init_sol[1] 125 | 126 | if n < 3: 127 | # This is optimal for n=2, so return the solution 128 | return sol_vec, init_sol[2] 129 | 130 | iter_count = 0 131 | 132 | grad_mat = np.zeros((n, n)) 133 | for i in range(n): 134 | for j in range(n): 135 | grad_mat[i, j] = dps[(i, j)] 136 | 137 | while iter_count < MinNormSolver.MAX_ITER: 138 | grad_dir = -1.0 * np.dot(grad_mat, sol_vec) 139 | new_point = MinNormSolver._next_point(sol_vec, grad_dir, n) 140 | # Re-compute the inner products for line search 141 | v1v1 = 0.0 142 | v1v2 = 0.0 143 | v2v2 = 0.0 144 | for i in range(n): 145 | for j in range(n): 146 | v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)] 147 | v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)] 148 | v2v2 += new_point[i] * new_point[j] * dps[(i, j)] 149 | nc, 
def gradient_normalizers(grads, losses, normalization_type):
    """Compute per-task gradient normalization constants.

    Parameters
    ----------
    grads : dict mapping task -> list of gradient tensors.
    losses : dict mapping task -> scalar loss value.
    normalization_type : one of "norm", "loss", "loss+", "none".

    Returns
    -------
    dict mapping task -> normalizer. Empty-valued entries are never added
    for an invalid ``normalization_type``; an error is printed instead
    (preserving the original best-effort behavior).
    """
    gn = {}
    if normalization_type == "norm":
        for t in grads:
            # .item() replaces the long-removed `.data[0]`, which raises
            # IndexError on 0-dim tensors in modern PyTorch.
            gn[t] = np.sqrt(np.sum([gr.pow(2).sum().item() for gr in grads[t]]))
    elif normalization_type == "loss":
        for t in grads:
            gn[t] = losses[t]
    elif normalization_type == "loss+":
        for t in grads:
            gn[t] = losses[t] * np.sqrt(
                np.sum([gr.pow(2).sum().item() for gr in grads[t]])
            )
    elif normalization_type == "none":
        for t in grads:
            gn[t] = 1.0
    else:
        print("ERROR: Invalid Normalization Type")
    return gn
Q.min(), Q.max(), '|', 45 | rs.min(), rs.max(), '|', 46 | np.trace(Q), np.trace(D.dot(Q))) 47 | print('Final objective', np.trace(D.dot(Q))) 48 | 49 | return Q 50 | 51 | 52 | def sdp_km_burer_monteiro(X, n_clusters, rank=None, maxiter=1e3, tol=1e-5): 53 | if rank is None: 54 | rank = 8 * n_clusters 55 | 56 | X_norm = X - np.mean(X, axis=0) 57 | if X_norm.shape[0] > X_norm.shape[1]: 58 | cov = X_norm.T.dot(X_norm) 59 | else: 60 | XXt = X_norm.dot(X_norm.T) 61 | cov = XXt 62 | X_norm /= np.trace(cov.dot(cov)) ** 0.25 63 | 64 | if X_norm.shape[0] <= X_norm.shape[1]: 65 | XXt /= np.trace(cov.dot(cov)) ** 0.5 66 | 67 | Y_shape = (len(X), rank) 68 | ones = np.ones((len(X), 1)) 69 | 70 | def lagrangian(x, lambda1, lambda2, sigma1, sigma2): 71 | Y = x.reshape(Y_shape) 72 | 73 | if X_norm.shape[0] > X_norm.shape[1]: 74 | YtX = Y.T.dot(X_norm) 75 | obj = -np.trace(YtX.dot(YtX.T)) 76 | else: 77 | obj = -np.trace(Y.T.dot(XXt).dot(Y)) 78 | 79 | trYtY_minus_nclusters = np.trace(Y.T.dot(Y)) - n_clusters 80 | obj -= lambda1 * trYtY_minus_nclusters 81 | obj += .5 * sigma1 * trYtY_minus_nclusters ** 2 82 | 83 | YYt1_minus_1 = Y.dot(Y.T.dot(ones)) - ones 84 | obj -= lambda2.T.dot(YYt1_minus_1)[0, 0] 85 | obj += .5 * sigma2 * np.sum(YYt1_minus_1 ** 2) 86 | 87 | return obj 88 | 89 | def grad(x, lambda1, lambda2, sigma1, sigma2): 90 | Y = x.reshape(Y_shape) 91 | 92 | if X_norm.shape[0] > X_norm.shape[1]: 93 | delta = -2 * X_norm.dot(X_norm.T.dot(Y)) 94 | else: 95 | delta = -2 * XXt.dot(Y) 96 | 97 | YtY = Y.T.dot(Y) 98 | delta -= 2 * (lambda1 99 | -sigma1 * (np.trace(Y.T.dot(Y)) - n_clusters)) * Y 100 | 101 | delta -= ones.dot(lambda2.T.dot(Y)) + lambda2.dot(ones.T.dot(Y)) 102 | 103 | Yt1 = Y.T.dot(ones) 104 | delta += sigma2 * (Y.dot(Yt1).dot(Yt1.T) + ones.dot(Yt1.T).dot(YtY) 105 | -2 * ones.dot(Yt1.T)) 106 | 107 | return delta.flatten() 108 | 109 | if X_norm.shape[0] > X_norm.shape[1]: 110 | Y = symnmf_gram_admm(X_norm, rank) 111 | else: 112 | Y = symnmf_admm(XXt, rank) 113 | 114 | 
lambda1 = 0. 115 | lambda2 = np.zeros((len(X), 1)) 116 | sigma1 = 1 117 | sigma2 = 1 118 | step = 1 119 | 120 | error = [] 121 | for n_iter in range(int(maxiter)): 122 | fun = partial(lagrangian, lambda1=lambda1, lambda2=lambda2, 123 | sigma1=sigma1, sigma2=sigma2) 124 | jac = partial(grad, lambda1=lambda1, lambda2=lambda2, 125 | sigma1=sigma1, sigma2=sigma2) 126 | bounds = [(0, 1)] * np.prod(Y_shape) 127 | 128 | Y_old = Y.copy() 129 | res = minimize(fun, Y.flatten(), jac=jac, bounds=bounds, 130 | method='L-BFGS-B',) 131 | Y = res.x.reshape(Y_shape) 132 | 133 | lambda1 -= step * sigma1 * (np.trace(Y.T.dot(Y)) - n_clusters) 134 | lambda2 -= step * sigma2 * (Y.dot(Y.T.dot(ones)) - ones) 135 | 136 | error.append(np.linalg.norm(Y - Y_old) / np.linalg.norm(Y_old)) 137 | 138 | if error[-1] < tol: 139 | break 140 | 141 | return Y 142 | 143 | 144 | def sdp_km_conditional_gradient(D, n_clusters, max_iter=2e3, 145 | stop_tol_max=1e-3, stop_tol_rmse=1e-4, 146 | n_inner_iter=15, 147 | use_line_search=False, 148 | verbose=False, track_stats=False): 149 | n = len(D) 150 | one_over_n = 1. / n 151 | 152 | def lagrangian(Q, lagrange_lower_bound, t): 153 | obj = -np.sum(D * Q) 154 | obj += np.sum(lagrange_lower_bound * (Q + one_over_n)) 155 | penalty = (t + 1) ** 0.5 156 | obj += penalty * np.sum(np.minimum(Q + one_over_n, 0) ** 2) 157 | return obj 158 | 159 | def gradient(Q, lagrange_lower_bound, t): 160 | delta = -D 161 | 162 | delta += lagrange_lower_bound 163 | penalty = (t + 1) ** 0.5 164 | delta += penalty * np.minimum(Q + one_over_n, 0) 165 | return delta 166 | 167 | def solve_lp(grad, t): 168 | # The following commented 2 lines are the mathematically 169 | # friendly version of the 2 lines following them. 
170 | # ortho_mat = np.eye(n) - np.ones_like(D) / n 171 | # A = ortho_mat.dot(grad).dot(ortho_mat) 172 | grad11 = np.broadcast_to(grad.sum(axis=1) / n, (n, n)) 173 | A = grad - grad11 - grad11.T + grad.sum() / (n ** 2) 174 | 175 | tol = (t + 1) ** -1 176 | return sp.linalg.eigsh(A, k=1, which='SA', tol=tol) 177 | 178 | def line_search(update, current, lagrange_lower_bound, t, tol=1e-5): 179 | gr = (np.sqrt(5) + 1) / 2 180 | a = 0 181 | b = 1 182 | 183 | c = b - (b - a) / gr 184 | d = a + (b - a) / gr 185 | while abs(c - d) > tol: 186 | convex_comb_c = c * update + (1 - c) * current 187 | convex_comb_d = d * update + (1 - d) * current 188 | fc = lagrangian(convex_comb_c, lagrange_lower_bound, t) 189 | fd = lagrangian(convex_comb_d, lagrange_lower_bound, t) 190 | if fc < fd: 191 | b = d 192 | else: 193 | a = c 194 | 195 | # we recompute both c and d here to avoid loss of precision 196 | # which may lead to incorrect results or infinite loop 197 | c = b - (b - a) / gr 198 | d = a + (b - a) / gr 199 | 200 | return (b + a) / 2 201 | 202 | 203 | if track_stats or verbose: 204 | rmse_list = [] 205 | obj_value_list = [] 206 | 207 | Q = np.zeros_like(D) 208 | lagrange_lower_bound = np.zeros(D.shape) 209 | step = 1 210 | 211 | for t in range(int(max_iter)): 212 | for inner_it in range(n_inner_iter): 213 | grad = gradient(Q, lagrange_lower_bound, t) 214 | s, v, = solve_lp(grad, inner_it) 215 | 216 | if s < 0: 217 | update = (n_clusters - 1) * np.outer(v, v) 218 | if use_line_search: 219 | eta = line_search(update, Q, lagrange_lower_bound, 220 | t * n_inner_iter + inner_it) 221 | else: 222 | eta = 2. 
/ (t * n_inner_iter + inner_it + 2) 223 | Q = (1 - eta) * Q + eta * update 224 | 225 | Q_nneg = Q + one_over_n 226 | 227 | rmse = np.sqrt(np.mean(Q_nneg[Q_nneg < 0] ** 2)) 228 | max_error = np.abs(np.min(Q_nneg[Q_nneg < 0])) / (n_clusters / n) 229 | 230 | if track_stats or verbose: 231 | obj_value_list.append(np.trace(D.dot(Q))) 232 | rmse_list.append(rmse) 233 | 234 | lagrange_lower_bound += step * Q_nneg 235 | np.minimum(lagrange_lower_bound, 0, out=lagrange_lower_bound) 236 | 237 | if verbose and t % 10 == 0: 238 | row_sum = Q.sum(axis=1) 239 | print('iteration', t, '|', 240 | -one_over_n, Q.min(), rmse, max_error, '|', 241 | row_sum.min(), row_sum.max(), '|', 242 | np.trace(Q), np.trace(D.dot(Q)), '|', 243 | eta) 244 | 245 | if max_error < stop_tol_max and rmse < stop_tol_rmse: 246 | if verbose: 247 | row_sum = Q.sum(axis=1) 248 | print('iteration', t, '|', 249 | -one_over_n, Q.min(), rmse, max_error, '|', 250 | row_sum.min(), row_sum.max(), '|', 251 | np.trace(Q), np.trace(D.dot(Q)), '|', 252 | eta) 253 | break 254 | 255 | Q += one_over_n 256 | 257 | if verbose: 258 | row_sum = np.mean(Q.sum(axis=1)) 259 | print('sum constraint', row_sum.min(), row_sum.max()) 260 | print('trace constraint', np.trace(Q)) 261 | print('nonnegative constraint', np.min(Q), np.mean(np.minimum(Q, 0))) 262 | 263 | print('final objective', np.trace(D.dot(Q))) 264 | 265 | if track_stats: 266 | return Q, rmse_list, obj_value_list 267 | else: 268 | return Q 269 | -------------------------------------------------------------------------------- /experiments/nyuv2/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import wandb 3 | from argparse import ArgumentParser 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | from tqdm import trange 10 | 11 | from experiments.nyuv2.data import NYUv2 12 | from experiments.nyuv2.models import SegNet, SegNetMtan 13 
| from experiments.nyuv2.utils import ConfMatrix, delta_fn, depth_error, normal_error 14 | from experiments.utils import ( 15 | common_parser, 16 | extract_weight_method_parameters_from_args, 17 | get_device, 18 | set_logger, 19 | set_seed, 20 | str2bool, 21 | list_of_float, 22 | ) 23 | from methods.weight_methods import WeightMethods 24 | 25 | #----------------- 26 | import pdb 27 | import os 28 | import time 29 | import sys 30 | #----------------- 31 | 32 | set_logger() 33 | 34 | #----------------- 35 | def log_string(file_out, out_str, print_out=True): 36 | file_out.write(out_str+'\n') 37 | file_out.flush() 38 | if print_out: 39 | print(out_str) 40 | # ----------------- 41 | 42 | def calc_loss(x_pred, x_output, task_type): 43 | device = x_pred.device 44 | 45 | # binary mark to mask out undefined pixel space 46 | binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(device) 47 | 48 | if task_type == "semantic": 49 | # semantic loss: depth-wise cross entropy 50 | loss = F.nll_loss(x_pred, x_output, ignore_index=-1) 51 | 52 | if task_type == "depth": 53 | # depth loss: l1 norm 54 | loss = torch.sum(torch.abs(x_pred - x_output) * binary_mask) / torch.nonzero( 55 | binary_mask, as_tuple=False 56 | ).size(0) 57 | 58 | if task_type == "normal": 59 | # normal loss: dot product 60 | loss = 1 - torch.sum((x_pred * x_output) * binary_mask) / torch.nonzero( 61 | binary_mask, as_tuple=False 62 | ).size(0) 63 | 64 | return loss 65 | 66 | 67 | def main(args, file_out, file_weight_out, file_loss_before, file_loss_after, path, lr, bs, device): 68 | # ---- 69 | # Nets 70 | # --- 71 | model = dict(segnet=SegNet(), mtan=SegNetMtan())[args.model] 72 | model = model.to(device) 73 | 74 | # dataset and dataloaders 75 | log_str = ( 76 | "Applying data augmentation on NYUv2." 77 | if args.apply_augmentation 78 | else "Standard training strategy without data augmentation." 
79 | ) 80 | logging.info(log_str) 81 | 82 | nyuv2_train_set = NYUv2( 83 | root=path.as_posix(), train=True, augmentation=args.apply_augmentation 84 | ) 85 | nyuv2_test_set = NYUv2(root=path.as_posix(), train=False) 86 | 87 | train_loader = torch.utils.data.DataLoader( 88 | dataset=nyuv2_train_set, batch_size=bs, shuffle=True 89 | ) 90 | 91 | test_loader = torch.utils.data.DataLoader( 92 | dataset=nyuv2_test_set, batch_size=bs, shuffle=False 93 | ) 94 | 95 | # weight method 96 | weight_methods_parameters = extract_weight_method_parameters_from_args(args) 97 | weight_method = WeightMethods(args.method, n_tasks=3, device=device, **weight_methods_parameters[args.method]) 98 | 99 | # optimizer 100 | optimizer = torch.optim.Adam( 101 | [ 102 | dict(params=model.parameters(), lr=lr), 103 | dict(params=weight_method.parameters(), lr=args.method_params_lr), 104 | ], 105 | ) 106 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) 107 | 108 | epochs = args.n_epochs 109 | epoch_iter = trange(epochs) 110 | train_batch = len(train_loader) 111 | test_batch = len(test_loader) 112 | avg_cost = np.zeros([epochs, 24], dtype=np.float32) 113 | # ----------------- 114 | avg_overall_performance = np.zeros([epochs, 2], dtype=np.float32) 115 | # ----------------- 116 | custom_step = -1 117 | conf_mat = ConfMatrix(model.segnet.class_nb) 118 | 119 | # training 120 | for epoch in epoch_iter: 121 | # ---- 122 | start = time.time() 123 | # ---- 124 | cost = np.zeros(24, dtype=np.float32) 125 | 126 | for j, batch in enumerate(train_loader): 127 | #----------------- 128 | # pdb.set_trace() 129 | if args.debugging: 130 | if j == 2: 131 | break 132 | # ----------------- 133 | custom_step += 1 134 | model.train() 135 | optimizer.zero_grad() 136 | 137 | train_data, train_label, train_depth, train_normal = batch 138 | train_data, train_label = train_data.to(device), train_label.long().to( 139 | device 140 | ) 141 | train_depth, train_normal = train_depth.to(device), 
train_normal.to(device) 142 | 143 | train_pred, features = model(train_data, return_representation=True) 144 | 145 | losses = torch.stack( 146 | ( 147 | calc_loss(train_pred[0], train_label, "semantic"), 148 | calc_loss(train_pred[1], train_depth, "depth"), 149 | calc_loss(train_pred[2], train_normal, "normal"), 150 | ) 151 | ) 152 | 153 | # pdb.set_trace() 154 | loss, extra_outputs = weight_method.backward( 155 | losses=losses, 156 | shared_parameters=list(model.shared_parameters()), 157 | task_specific_parameters=list(model.task_specific_parameters()), 158 | last_shared_parameters=list(model.last_shared_parameters()), 159 | representation=features, 160 | ) 161 | # pdb.set_trace() 162 | 163 | optimizer.step() 164 | 165 | #----------updated from FAMO 166 | if "famo" in args.method: 167 | with torch.no_grad(): 168 | train_pred = model(train_data, return_representation=False) 169 | new_losses = torch.stack( 170 | ( 171 | calc_loss(train_pred[0], train_label, "semantic"), 172 | calc_loss(train_pred[1], train_depth, "depth"), 173 | calc_loss(train_pred[2], train_normal, "normal"), 174 | ) 175 | ) 176 | weight_method.method.update(new_losses.detach()) 177 | #---------- 178 | 179 | # accumulate label prediction for every pixel in training images 180 | conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten()) 181 | 182 | cost[0] = losses[0].item() 183 | cost[3] = losses[1].item() 184 | cost[4], cost[5] = depth_error(train_pred[1], train_depth) 185 | cost[6] = losses[2].item() 186 | cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(train_pred[2], train_normal) 187 | avg_cost[epoch, :12] += cost[:12] / train_batch 188 | 189 | epoch_iter.set_description( 190 | f"[{epoch+1} {j+1}/{train_batch}] semantic loss: {losses[0].item():.3f}, " 191 | f"depth loss: {losses[1].item():.3f}, " 192 | f"normal loss: {losses[2].item():.3f}" 193 | ) 194 | 195 | # scheduler 196 | scheduler.step() 197 | 198 | # compute mIoU and acc 199 | avg_cost[epoch, 1:3] = 
conf_mat.get_metrics() 200 | 201 | # todo: move evaluate to function? 202 | # evaluating test data 203 | model.eval() 204 | conf_mat = ConfMatrix(model.segnet.class_nb) 205 | with torch.no_grad(): # operations inside don't track history 206 | test_dataset = iter(test_loader) 207 | for k in range(test_batch): 208 | # ----------------- 209 | # pdb.set_trace() 210 | if args.debugging: 211 | if k == 2: 212 | break 213 | # ----------------- 214 | test_data, test_label, test_depth, test_normal = next(test_dataset) #test_dataset.next()#.next is deprecated 215 | test_data, test_label = test_data.to(device), test_label.long().to( 216 | device 217 | ) 218 | test_depth, test_normal = test_depth.to(device), test_normal.to(device) 219 | 220 | test_pred = model(test_data) 221 | test_loss = torch.stack( 222 | ( 223 | calc_loss(test_pred[0], test_label, "semantic"), 224 | calc_loss(test_pred[1], test_depth, "depth"), 225 | calc_loss(test_pred[2], test_normal, "normal"), 226 | ) 227 | ) 228 | 229 | conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten()) 230 | 231 | cost[12] = test_loss[0].item() 232 | cost[15] = test_loss[1].item() 233 | cost[16], cost[17] = depth_error(test_pred[1], test_depth) 234 | cost[18] = test_loss[2].item() 235 | cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error( 236 | test_pred[2], test_normal 237 | ) 238 | avg_cost[epoch, 12:] += cost[12:] / test_batch 239 | # compute mIoU and acc 240 | avg_cost[epoch, 13:15] = conf_mat.get_metrics() 241 | 242 | # Test Delta_m 243 | test_delta_m = delta_fn( 244 | avg_cost[epoch, [13, 14, 16, 17, 19, 20, 21, 22, 23]] 245 | ) 246 | 247 | # ----------------- 248 | avg_overall_performance[epoch, 0] = test_delta_m 249 | avg_overall_performance[epoch, 1] = avg_overall_performance[:epoch+1, 0].min() 250 | # ----------------- 251 | 252 | # print results 253 | # print( 254 | # f"LOSS FORMAT: SEMANTIC_LOSS MEAN_IOU PIX_ACC | DEPTH_LOSS ABS_ERR REL_ERR " 255 | # f"| NORMAL_LOSS MEAN MED <11.25 <22.5 <30 
| ∆m (test)" 256 | # ) 257 | # print("") 258 | 259 | # ---- 260 | end = time.time() 261 | log_string(file_out, 'Epoch: {:04d} | Epoch time {:.4f}'.format(epoch, end - start)) 262 | # if args.method in {"famo"}: 263 | log_string(file_weight_out, 'Epoch {:04d} | {}'.format(epoch, extra_outputs["weights"]), False) 264 | log_string(file_loss_before, 'Epoch {:04d} | {}'.format(epoch, losses.detach().cpu()), False) 265 | log_string(file_loss_after, 'Epoch {:04d} | {}'.format(epoch, loss.detach().cpu()), False) 266 | # ---- 267 | 268 | log_string(file_out, f"Epoch: {epoch:04d} | TRAIN: " 269 | f"{avg_cost[epoch, 0]:.4f} {avg_cost[epoch, 1]*100:.2f} {avg_cost[epoch, 2]*100:.2f} | " 270 | f"{avg_cost[epoch, 3]:.4f} {avg_cost[epoch, 4]:.4f} {avg_cost[epoch, 5]:.4f} | " 271 | f"{avg_cost[epoch, 6]:.4f} {avg_cost[epoch, 7]:.2f} {avg_cost[epoch, 8]:.2f} {avg_cost[epoch, 9]*100:.2f} {avg_cost[epoch, 10]*100:.2f} {avg_cost[epoch, 11]*100:.2f} || " 272 | f"TEST: " 273 | f"{avg_cost[epoch, 12]:.4f} {avg_cost[epoch, 13]*100:.2f} {avg_cost[epoch, 14]*100:.2f} | " 274 | f"{avg_cost[epoch, 15]:.4f} {avg_cost[epoch, 16]:.4f} {avg_cost[epoch, 17]:.4f} | " 275 | f"{avg_cost[epoch, 18]:.4f} {avg_cost[epoch, 19]:.2f} {avg_cost[epoch, 20]:.2f} {avg_cost[epoch, 21]*100:.2f} {avg_cost[epoch, 22]*100:.2f} {avg_cost[epoch, 23]*100:.2f} " 276 | f"| {test_delta_m:.3f} {avg_overall_performance[epoch, -1]:.3f}" 277 | ) 278 | 279 | if wandb.run is not None: 280 | wandb.log({"Train Semantic Loss": avg_cost[epoch, 0]}, step=epoch) 281 | wandb.log({"Train Mean IoU": avg_cost[epoch, 1]}, step=epoch) 282 | wandb.log({"Train Pixel Accuracy": avg_cost[epoch, 2]}, step=epoch) 283 | wandb.log({"Train Depth Loss": avg_cost[epoch, 3]}, step=epoch) 284 | wandb.log({"Train Absolute Error": avg_cost[epoch, 4]}, step=epoch) 285 | wandb.log({"Train Relative Error": avg_cost[epoch, 5]}, step=epoch) 286 | wandb.log({"Train Normal Loss": avg_cost[epoch, 6]}, step=epoch) 287 | wandb.log({"Train Loss Mean": 
avg_cost[epoch, 7]}, step=epoch) 288 | wandb.log({"Train Loss Med": avg_cost[epoch, 8]}, step=epoch) 289 | wandb.log({"Train Loss <11.25": avg_cost[epoch, 9]}, step=epoch) 290 | wandb.log({"Train Loss <22.5": avg_cost[epoch, 10]}, step=epoch) 291 | wandb.log({"Train Loss <30": avg_cost[epoch, 11]}, step=epoch) 292 | 293 | wandb.log({"Test Semantic Loss": avg_cost[epoch, 12]}, step=epoch) 294 | wandb.log({"Test Mean IoU": avg_cost[epoch, 13]}, step=epoch) 295 | wandb.log({"Test Pixel Accuracy": avg_cost[epoch, 14]}, step=epoch) 296 | wandb.log({"Test Depth Loss": avg_cost[epoch, 15]}, step=epoch) 297 | wandb.log({"Test Absolute Error": avg_cost[epoch, 16]}, step=epoch) 298 | wandb.log({"Test Relative Error": avg_cost[epoch, 17]}, step=epoch) 299 | wandb.log({"Test Normal Loss": avg_cost[epoch, 18]}, step=epoch) 300 | wandb.log({"Test Loss Mean": avg_cost[epoch, 19]}, step=epoch) 301 | wandb.log({"Test Loss Med": avg_cost[epoch, 20]}, step=epoch) 302 | wandb.log({"Test Loss <11.25": avg_cost[epoch, 21]}, step=epoch) 303 | wandb.log({"Test Loss <22.5": avg_cost[epoch, 22]}, step=epoch) 304 | wandb.log({"Test Loss <30": avg_cost[epoch, 23]}, step=epoch) 305 | wandb.log({"Test ∆m": test_delta_m}, step=epoch) 306 | 307 | # #------------------------------------------- 308 | # wandb.log({"Weight_seg": extra_outputs["weights"][0].item()}, step=epoch) 309 | # wandb.log({"Weight_depth": extra_outputs["weights"][1].item()}, step=epoch) 310 | # wandb.log({"Weight_normal": extra_outputs["weights"][2].item()}, step=epoch) 311 | # # ------------------------------------------- 312 | 313 | # ----------------- 314 | # final output 315 | log_string(file_out, f"FORMAT: MEAN_IOU PIX_ACC | ABS_ERR REL_ERR | MEAN MED <11.25 <22.5 <30 | ∆m (test)") 316 | log_string(file_out, 317 | f"Last_10_TEST: " 318 | f"{avg_cost[-10:, 13].mean()*100:.2f} {avg_cost[-10:, 14].mean()*100:.2f} " 319 | f"{avg_cost[-10:, 16].mean():.4f} {avg_cost[-10:, 17].mean():.4f} " 320 | f"{avg_cost[-10:, 
19].mean():.2f} {avg_cost[-10:, 20].mean():.2f} {avg_cost[-10:, 21].mean()*100:.2f} {avg_cost[-10:, 22].mean()*100:.2f} {avg_cost[-10:, 23].mean()*100:.2f} " 321 | f"{avg_overall_performance[-10:, 0].mean():.3f}" 322 | ) 323 | # ----------------- 324 | 325 | # pdb.set_trace() 326 | 327 | if __name__ == "__main__": 328 | parser = ArgumentParser("NYUv2", parents=[common_parser]) 329 | parser.set_defaults( 330 | data_path = "../../../../../../../dataset/SIMO/nyuv2", 331 | lr=1e-4, 332 | n_epochs=200, 333 | batch_size=2, 334 | ) 335 | 336 | parser.add_argument("--model", type=str, default="mtan", choices=["segnet", "mtan"], help="model type") 337 | parser.add_argument("--apply-augmentation", type=str2bool, default=True, help="data augmentations") 338 | # parser.add_argument("--wandb_project", type=str, default="nashmtl_nyuv2", help="Name of Weights & Biases Project.") 339 | # parser.add_argument("--wandb_entity", type=str, default="jia-yi9999", help="Name of Weights & Biases Entity.") 340 | parser.add_argument("--wandb_project", type=str, default=None, help="Name of Weights & Biases Project.") 341 | parser.add_argument("--wandb_entity", type=str, default=None, help="Name of Weights & Biases Entity.") 342 | # ----------------- 343 | parser.add_argument('--debugging', action='store_true', help='with debugging') 344 | parser.add_argument('--log_name', default="logs_segnet_mtan/log", type=str, help='log name') 345 | parser.add_argument('--robust_step_size', default=0.0001, type=float, help='for our method') 346 | parser.add_argument('--task_weights', default="0.1,0.1,0.8", type=list_of_float, help='for group') 347 | parser.add_argument('--num_groups', default=1, type=int, help='number of groups') 348 | # ----------------- 349 | args = parser.parse_args() 350 | 351 | # set seed 352 | set_seed(args.seed) 353 | #----------------- 354 | if args.method == "go4align": 355 | log_name = args.log_name + '_Method={}_Ngroups={}_Seed={}'.format(args.method, args.num_groups, args.seed) 
356 | elif args.method == "group": 357 | log_name = args.log_name + '_Method={}_{}_{}_{}_Seed={}'.format(args.method, args.task_weights[0], args.task_weights[1], args.task_weights[2], args.seed) 358 | else: 359 | log_name = args.log_name + '_Method={}_Seed={}'.format(args.method, args.seed) 360 | 361 | os.system("mkdir -p " + log_name) 362 | file_out = open(log_name + "/train_log.txt", "w") 363 | file_weight_out = open(log_name + "/weight_log.txt", "w") 364 | file_loss_before = open(log_name + "/loss_before.txt", "w") 365 | file_loss_after = open(log_name+ "/loss_after.txt", "w") 366 | 367 | os.system("mkdir -p " + log_name + '/files') 368 | os.system('cp %s %s' % ('*.py', os.path.join(log_name, 'files'))) 369 | print(get_device(gpus=args.gpu)) 370 | #-------------------- 371 | 372 | if args.wandb_project is not None: 373 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, name=log_name, config=args) 374 | 375 | device = get_device(gpus=args.gpu) 376 | main(args, file_out, file_weight_out, file_loss_before, file_loss_after, path=args.data_path, lr=args.lr, bs=args.batch_size, device=device) 377 | 378 | if wandb.run is not None: 379 | wandb.finish() 380 | 381 | # ----------------- 382 | st = ' ' 383 | log_string(file_out, st.join(sys.argv)) 384 | file_out.close() 385 | #-------------------- -------------------------------------------------------------------------------- /methods/cluster_methods.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from abc import abstractmethod 4 | from typing import Dict, List, Tuple, Union 5 | 6 | import cvxpy as cp 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from scipy.optimize import minimize 11 | 12 | # from kmeans_pytorch import kmeans # training on gpus 13 | from sklearn.cluster import KMeans, SpectralClustering # training on cpu 14 | # assign labels after the Laplacian embedding. 
15 | # The cluster_qr method [5] directly extract clusters from eigenvectors in spectral clustering. 16 | # Simple, direct, and efficient multi-way spectral clustering, 2019 Anil Damle, Victor Minden, Lexing Ying 17 | 18 | from .sdp_kmeans import sdp_kmeans, connected_components, spectral_embedding 19 | # https://github.com/simonsfoundation/sdp_kmeans.git 20 | # Mariano Tepper, Anirvan Sengupta, Dmitri Chklovskii, The surprising secret identity of the semidefinite relaxation of K-means: manifold learning, 2017 21 | 22 | 23 | class WeightMethod: 24 | def __init__(self, n_tasks: int, device: torch.device): 25 | super().__init__() 26 | self.n_tasks = n_tasks 27 | self.device = device 28 | 29 | @abstractmethod 30 | def get_weighted_loss( 31 | self, 32 | losses: torch.Tensor, 33 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor], 34 | task_specific_parameters: Union[ 35 | List[torch.nn.parameter.Parameter], torch.Tensor 36 | ], 37 | last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor], 38 | representation: Union[torch.nn.parameter.Parameter, torch.Tensor], 39 | **kwargs, 40 | ): 41 | pass 42 | 43 | def backward( 44 | self, 45 | losses: torch.Tensor, 46 | shared_parameters: Union[ 47 | List[torch.nn.parameter.Parameter], torch.Tensor 48 | ] = None, 49 | task_specific_parameters: Union[ 50 | List[torch.nn.parameter.Parameter], torch.Tensor 51 | ] = None, 52 | last_shared_parameters: Union[ 53 | List[torch.nn.parameter.Parameter], torch.Tensor 54 | ] = None, 55 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 56 | **kwargs, 57 | ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]: 58 | """ 59 | 60 | Parameters 61 | ---------- 62 | losses : 63 | shared_parameters : 64 | task_specific_parameters : 65 | last_shared_parameters : parameters of last shared layer/block 66 | representation : shared representation 67 | kwargs : 68 | 69 | Returns 70 | ------- 71 | Loss, extra outputs 72 | """ 
73 | loss, extra_outputs = self.get_weighted_loss( 74 | losses=losses, 75 | shared_parameters=shared_parameters, 76 | task_specific_parameters=task_specific_parameters, 77 | last_shared_parameters=last_shared_parameters, 78 | representation=representation, 79 | **kwargs, 80 | ) 81 | loss.backward() 82 | return loss, extra_outputs 83 | 84 | def __call__( 85 | self, 86 | losses: torch.Tensor, 87 | shared_parameters: Union[ 88 | List[torch.nn.parameter.Parameter], torch.Tensor 89 | ] = None, 90 | task_specific_parameters: Union[ 91 | List[torch.nn.parameter.Parameter], torch.Tensor 92 | ] = None, 93 | **kwargs, 94 | ): 95 | return self.backward( 96 | losses=losses, 97 | shared_parameters=shared_parameters, 98 | task_specific_parameters=task_specific_parameters, 99 | **kwargs, 100 | ) 101 | 102 | def parameters(self) -> List[torch.Tensor]: 103 | """return learnable parameters""" 104 | return [] 105 | 106 | class GO4ALIGN(WeightMethod): 107 | def __init__( 108 | self, 109 | n_tasks: int, 110 | device: torch.device, 111 | num_groups: int, 112 | task_weights: Union[List[float], torch.Tensor] = None, 113 | robust_step_size: float = 0.0001, 114 | ): 115 | super().__init__(n_tasks, device=device) 116 | self.n_tasks = n_tasks 117 | self.device = device 118 | if task_weights is None: 119 | task_weights = torch.ones((n_tasks,)) 120 | if not isinstance(task_weights, torch.Tensor): 121 | task_weights = torch.tensor(task_weights) 122 | assert len(task_weights) == n_tasks 123 | self.task_weights = task_weights.to(device) 124 | 125 | self.adv_probs = torch.ones(n_tasks).to(device) / n_tasks 126 | self.robust_step_size = robust_step_size 127 | self.num_groups = num_groups 128 | 129 | def get_weighted_loss(self, losses, **kwargs): 130 | adjusted_loss = losses.detach() 131 | scale = adjusted_loss.sum() / adjusted_loss 132 | self.adv_probs = self.adv_probs * torch.exp(- self.robust_step_size * adjusted_loss) 133 | self.adv_probs = self.adv_probs / (self.adv_probs.sum()) 134 | weight = 
scale * self.adv_probs 135 | 136 | weight = weight.unsqueeze(1) 137 | 138 | if self.num_groups >=2: 139 | cluster_ids_x, cluster_centers = kmeans(X=weight, num_clusters=self.num_groups, distance='euclidean', device=self.device) 140 | mask = torch.zeros(self.n_tasks, self.num_groups).to(self.device) 141 | cluster_ids = cluster_ids_x.unsqueeze(1).to(self.device) 142 | cluster_centers = cluster_centers.to(self.device) 143 | mask.scatter_(1, cluster_ids, 1) 144 | kmeans_weight = torch.mm(mask,cluster_centers).squeeze(1) 145 | 146 | elif self.num_groups == 1: 147 | kmeans_weight = torch.ones(self.n_tasks).to(self.device) 148 | kmeans_weight = kmeans_weight * torch.mean(weight) 149 | 150 | loss = torch.sum(losses * kmeans_weight) 151 | return loss, dict(weights=torch.cat([kmeans_weight])) 152 | 153 | class GROUP(WeightMethod): 154 | def __init__( 155 | self, 156 | n_tasks: int, 157 | device: torch.device, 158 | task_weights: Union[List[float], torch.Tensor] = None, 159 | robust_step_size: float = 0.0001, 160 | ): 161 | super().__init__(n_tasks, device=device) 162 | self.n_tasks = n_tasks 163 | if task_weights is None: 164 | task_weights = torch.ones((n_tasks,)) 165 | if not isinstance(task_weights, torch.Tensor): 166 | task_weights = torch.tensor(task_weights) 167 | assert len(task_weights) == n_tasks 168 | self.task_weights = task_weights.to(device) 169 | 170 | def get_weighted_loss(self, losses, **kwargs): 171 | loss = torch.sum(losses * self.task_weights) * self.n_tasks 172 | return loss, dict(weights=torch.cat([self.task_weights])) 173 | 174 | class GROUP_RANDOM(WeightMethod): 175 | def __init__( 176 | self, 177 | n_tasks: int, 178 | device: torch.device, 179 | task_weights: Union[List[float], torch.Tensor] = None, 180 | robust_step_size: float = 0.0001, 181 | ): 182 | super().__init__(n_tasks, device=device) 183 | self.n_tasks = n_tasks 184 | if task_weights is None: 185 | task_weights = torch.ones((n_tasks,)) 186 | if not isinstance(task_weights, torch.Tensor): 187 
| task_weights = torch.tensor(task_weights) 188 | assert len(task_weights) == n_tasks 189 | self.task_weights = task_weights.to(device) 190 | 191 | def get_weighted_loss(self, losses, **kwargs): 192 | idx = torch.randperm(self.task_weights.shape[0]) 193 | weights = self.task_weights[idx] 194 | loss = torch.sum(losses * weights) * self.n_tasks 195 | return loss, dict(weights=torch.cat([weights])) 196 | 197 | class GROUP_sklearn_spectral_clustering_cluster_qr(WeightMethod): 198 | def __init__( 199 | self, 200 | n_tasks: int, 201 | device: torch.device, 202 | num_groups: int, 203 | task_weights: Union[List[float], torch.Tensor] = None, 204 | robust_step_size: float = 0.0001, 205 | ): 206 | super().__init__(n_tasks, device=device) 207 | self.n_tasks = n_tasks 208 | self.device = device 209 | if task_weights is None: 210 | task_weights = torch.ones((n_tasks,)) 211 | if not isinstance(task_weights, torch.Tensor): 212 | task_weights = torch.tensor(task_weights) 213 | assert len(task_weights) == n_tasks 214 | self.task_weights = task_weights.to(device) 215 | 216 | self.adv_probs = torch.ones(n_tasks).to(device) / n_tasks 217 | self.robust_step_size = robust_step_size 218 | self.num_groups = num_groups 219 | 220 | def get_weighted_loss(self, losses, **kwargs): 221 | adjusted_loss = losses.detach() 222 | scale = adjusted_loss.sum() / adjusted_loss 223 | self.adv_probs = self.adv_probs * torch.exp(- self.robust_step_size * adjusted_loss) 224 | self.adv_probs = self.adv_probs / (self.adv_probs.sum()) 225 | weight = scale * self.adv_probs 226 | 227 | weight = weight.unsqueeze(1) 228 | 229 | if self.num_groups >=2: 230 | # pdb.set_trace() 231 | results = SpectralClustering(n_clusters=self.num_groups, 232 | assign_labels='cluster_qr', 233 | affinity='linear', 234 | random_state=0).fit(weight.cpu()) 235 | cluster_ids = results.labels_ 236 | cluster_ids = torch.tensor(cluster_ids).unsqueeze(1).to(self.device) 237 | 238 | assignment_matrix = torch.zeros(self.n_tasks, 
self.num_groups).to(self.device) 239 | cluster_ids = cluster_ids.to(torch.int64) 240 | assignment_matrix.scatter_(1, cluster_ids, 1) 241 | assignment_matrix = assignment_matrix.to(torch.float64) 242 | 243 | cluster_centers = (assignment_matrix * weight).sum(dim=0) / assignment_matrix.sum(0) 244 | cluster_centers = cluster_centers.unsqueeze(1) 245 | 246 | group_weight = torch.mm(assignment_matrix, cluster_centers).squeeze(1) # 3*2, 2*1 247 | 248 | 249 | elif self.num_groups == 1: 250 | group_weight = torch.ones(self.n_tasks).to(self.device) 251 | group_weight = group_weight * torch.mean(weight) 252 | 253 | loss = torch.sum(losses * group_weight) 254 | return loss, dict(weights=torch.cat([group_weight])) 255 | 256 | class GROUP_sklearn_spectral_clustering_discretize(WeightMethod): 257 | def __init__( 258 | self, 259 | n_tasks: int, 260 | device: torch.device, 261 | num_groups: int, 262 | task_weights: Union[List[float], torch.Tensor] = None, 263 | robust_step_size: float = 0.0001, 264 | ): 265 | super().__init__(n_tasks, device=device) 266 | self.n_tasks = n_tasks 267 | self.device = device 268 | if task_weights is None: 269 | task_weights = torch.ones((n_tasks,)) 270 | if not isinstance(task_weights, torch.Tensor): 271 | task_weights = torch.tensor(task_weights) 272 | assert len(task_weights) == n_tasks 273 | self.task_weights = task_weights.to(device) 274 | 275 | self.adv_probs = torch.ones(n_tasks).to(device) / n_tasks 276 | self.robust_step_size = robust_step_size 277 | self.num_groups = num_groups 278 | 279 | def get_weighted_loss(self, losses, **kwargs): 280 | adjusted_loss = losses.detach() 281 | scale = adjusted_loss.sum() / adjusted_loss 282 | self.adv_probs = self.adv_probs * torch.exp(- self.robust_step_size * adjusted_loss) 283 | self.adv_probs = self.adv_probs / (self.adv_probs.sum()) 284 | weight = scale * self.adv_probs 285 | 286 | weight = weight.unsqueeze(1) 287 | 288 | if self.num_groups >=2: 289 | # pdb.set_trace() 290 | results = 
SpectralClustering(n_clusters=self.num_groups, 291 | assign_labels='discretize', 292 | affinity='linear', 293 | random_state=0).fit(weight.cpu()) 294 | cluster_ids = results.labels_ 295 | cluster_ids = torch.tensor(cluster_ids).unsqueeze(1).to(self.device) 296 | 297 | assignment_matrix = torch.zeros(self.n_tasks, self.num_groups).to(self.device) 298 | cluster_ids = cluster_ids.to(torch.int64) 299 | assignment_matrix.scatter_(1, cluster_ids, 1) 300 | assignment_matrix = assignment_matrix.to(torch.float64) 301 | 302 | cluster_centers = (assignment_matrix * weight).sum(dim=0) / assignment_matrix.sum(0) 303 | cluster_centers = cluster_centers.unsqueeze(1) 304 | 305 | group_weight = torch.mm(assignment_matrix, cluster_centers).squeeze(1) # 3*2, 2*1 306 | 307 | 308 | elif self.num_groups == 1: 309 | group_weight = torch.ones(self.n_tasks).to(self.device) 310 | group_weight = group_weight * torch.mean(weight) 311 | 312 | loss = torch.sum(losses * group_weight) 313 | return loss, dict(weights=torch.cat([group_weight])) 314 | 315 | class GROUP_sklearn_spectral_clustering_kmeans(WeightMethod): 316 | def __init__( 317 | self, 318 | n_tasks: int, 319 | device: torch.device, 320 | num_groups: int, 321 | task_weights: Union[List[float], torch.Tensor] = None, 322 | robust_step_size: float = 0.0001, 323 | ): 324 | super().__init__(n_tasks, device=device) 325 | self.n_tasks = n_tasks 326 | self.device = device 327 | if task_weights is None: 328 | task_weights = torch.ones((n_tasks,)) 329 | if not isinstance(task_weights, torch.Tensor): 330 | task_weights = torch.tensor(task_weights) 331 | assert len(task_weights) == n_tasks 332 | self.task_weights = task_weights.to(device) 333 | 334 | self.adv_probs = torch.ones(n_tasks).to(device) / n_tasks 335 | self.robust_step_size = robust_step_size 336 | self.num_groups = num_groups 337 | 338 | def get_weighted_loss(self, losses, **kwargs): 339 | adjusted_loss = losses.detach() 340 | scale = adjusted_loss.sum() / adjusted_loss 341 | 
self.adv_probs = self.adv_probs * torch.exp(- self.robust_step_size * adjusted_loss) 342 | self.adv_probs = self.adv_probs / (self.adv_probs.sum()) 343 | weight = scale * self.adv_probs 344 | 345 | weight = weight.unsqueeze(1) 346 | 347 | if self.num_groups >=2: 348 | # pdb.set_trace() 349 | results = SpectralClustering(n_clusters=self.num_groups, 350 | assign_labels='kmeans', 351 | affinity='linear', 352 | random_state=0).fit(weight.cpu()) 353 | cluster_ids = results.labels_ 354 | cluster_ids = torch.tensor(cluster_ids).unsqueeze(1).to(self.device) 355 | 356 | assignment_matrix = torch.zeros(self.n_tasks, self.num_groups).to(self.device) 357 | cluster_ids = cluster_ids.to(torch.int64) 358 | assignment_matrix.scatter_(1, cluster_ids, 1) 359 | assignment_matrix = assignment_matrix.to(torch.float64) 360 | 361 | cluster_centers = (assignment_matrix * weight).sum(dim=0) / assignment_matrix.sum(0) 362 | cluster_centers = cluster_centers.unsqueeze(1) 363 | 364 | group_weight = torch.mm(assignment_matrix, cluster_centers).squeeze(1) # 3*2, 2*1 365 | 366 | 367 | elif self.num_groups == 1: 368 | group_weight = torch.ones(self.n_tasks).to(self.device) 369 | group_weight = group_weight * torch.mean(weight) 370 | 371 | loss = torch.sum(losses * group_weight) 372 | return loss, dict(weights=torch.cat([group_weight])) 373 | 374 | class GROUP_sdp_clustering(WeightMethod): 375 | def __init__( 376 | self, 377 | n_tasks: int, 378 | device: torch.device, 379 | num_groups: int, 380 | task_weights: Union[List[float], torch.Tensor] = None, 381 | robust_step_size: float = 0.0001, 382 | ): 383 | super().__init__(n_tasks, device=device) 384 | self.n_tasks = n_tasks 385 | self.device = device 386 | if task_weights is None: 387 | task_weights = torch.ones((n_tasks,)) 388 | if not isinstance(task_weights, torch.Tensor): 389 | task_weights = torch.tensor(task_weights) 390 | assert len(task_weights) == n_tasks 391 | self.task_weights = task_weights.to(device) 392 | 393 | self.adv_probs = 
torch.ones(n_tasks).to(device) / n_tasks 394 | self.robust_step_size = robust_step_size 395 | self.num_groups = num_groups 396 | 397 | def get_weighted_loss(self, losses, **kwargs): 398 | adjusted_loss = losses.detach() 399 | scale = adjusted_loss.sum() / adjusted_loss 400 | self.adv_probs = self.adv_probs * torch.exp(- self.robust_step_size * adjusted_loss) 401 | self.adv_probs = self.adv_probs / (self.adv_probs.sum()) 402 | weight = scale * self.adv_probs 403 | 404 | weight = weight.unsqueeze(1) 405 | 406 | if self.num_groups >=2: 407 | D, Q = sdp_kmeans(X=weight.cpu().numpy(), n_clusters=self.num_groups, method='cvx') 408 | assignment_matrix_init = connected_components(Q) 409 | assignment_matrix = torch.tensor(assignment_matrix_init).to(self.device) 410 | assignment_matrix = assignment_matrix.transpose(0,1).to(torch.float64) 411 | 412 | cluster_centers = (assignment_matrix * weight).sum(dim=0) / assignment_matrix.sum(0) 413 | cluster_centers = cluster_centers.unsqueeze(1) 414 | 415 | group_weight = torch.mm(assignment_matrix, cluster_centers).squeeze(1) # 3*2, 2*1 416 | 417 | elif self.num_groups == 1: 418 | group_weight = torch.ones(self.n_tasks).to(self.device) 419 | group_weight = group_weight * torch.mean(weight) 420 | 421 | loss = torch.sum(losses * group_weight) 422 | return loss, dict(weights=torch.cat([group_weight])) 423 | 424 | -------------------------------------------------------------------------------- /experiments/nyuv2/models.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class _SegNet(nn.Module): 9 | """SegNet MTAN""" 10 | 11 | def __init__(self): 12 | super(_SegNet, self).__init__() 13 | # initialise network parameters 14 | filter = [64, 128, 256, 512, 512] 15 | self.class_nb = 13 16 | 17 | # define encoder decoder layers 18 | self.encoder_block = nn.ModuleList([self.conv_layer([3, 
filter[0]])])
        self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
        for i in range(4):
            self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
            self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))

        # define convolution layer
        self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
        self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
        for i in range(4):
            if i == 0:
                self.conv_block_enc.append(
                    self.conv_layer([filter[i + 1], filter[i + 1]])
                )
                self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
            else:
                self.conv_block_enc.append(
                    nn.Sequential(
                        self.conv_layer([filter[i + 1], filter[i + 1]]),
                        self.conv_layer([filter[i + 1], filter[i + 1]]),
                    )
                )
                self.conv_block_dec.append(
                    nn.Sequential(
                        self.conv_layer([filter[i], filter[i]]),
                        self.conv_layer([filter[i], filter[i]]),
                    )
                )

        # define task attention layers (nested lists are indexed [task][stage])
        self.encoder_att = nn.ModuleList(
            [nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])]
        )
        self.decoder_att = nn.ModuleList(
            [nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])])]
        )
        self.encoder_block_att = nn.ModuleList(
            [self.conv_layer([filter[0], filter[1]])]
        )
        self.decoder_block_att = nn.ModuleList(
            [self.conv_layer([filter[0], filter[0]])]
        )

        # 3 tasks; the first task's stage-0 attention was created above,
        # so only 2 more task lists are appended (j < 2).
        for j in range(3):
            if j < 2:
                self.encoder_att.append(
                    nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])
                )
                self.decoder_att.append(
                    nn.ModuleList(
                        [self.att_layer([2 * filter[0], filter[0], filter[0]])]
                    )
                )
            for i in range(4):
                self.encoder_att[j].append(
                    self.att_layer([2 * filter[i + 1], filter[i + 1], filter[i + 1]])
                )
                self.decoder_att[j].append(
                    self.att_layer([filter[i + 1] + filter[i], filter[i], filter[i]])
                )

        for i in range(4):
            if i < 3:
                self.encoder_block_att.append(
                    self.conv_layer([filter[i + 1], filter[i + 2]])
                )
                self.decoder_block_att.append(
                    self.conv_layer([filter[i + 1], filter[i]])
                )
            else:
                self.encoder_block_att.append(
                    self.conv_layer([filter[i + 1], filter[i + 1]])
                )
                self.decoder_block_att.append(
                    self.conv_layer([filter[i + 1], filter[i + 1]])
                )

        # Task heads: task1 -> class_nb channels, task2 -> 1, task3 -> 3.
        # NOTE(review): presumably segmentation / depth / surface normals
        # (NYUv2) — confirm against the trainer.
        self.pred_task1 = self.conv_layer([filter[0], self.class_nb], pred=True)
        self.pred_task2 = self.conv_layer([filter[0], 1], pred=True)
        self.pred_task3 = self.conv_layer([filter[0], 3], pred=True)

        # define pooling and unpooling functions
        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)

        # Xavier init for conv/linear weights; unit scale, zero shift for BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def shared_modules(self):
        """Modules shared across tasks (task attention modules excluded)."""
        return [
            self.encoder_block,
            self.decoder_block,
            self.conv_block_enc,
            self.conv_block_dec,
            # self.encoder_att, self.decoder_att,
            self.encoder_block_att,
            self.decoder_block_att,
            self.down_sampling,
            self.up_sampling,
        ]

    def zero_grad_shared_modules(self):
        """Zero the gradients of every shared module."""
        for mm in self.shared_modules():
            mm.zero_grad()

    def conv_layer(self, channel, pred=False):
        """conv->BN->ReLU block; with pred=True, a two-conv prediction head."""
        if not pred:
            conv_block = nn.Sequential(
                nn.Conv2d(
                    in_channels=channel[0],
                    out_channels=channel[1],
                    kernel_size=3,
                    padding=1,
                ),
                nn.BatchNorm2d(num_features=channel[1]),
                nn.ReLU(inplace=True),
            )
        else:
            conv_block = nn.Sequential(
                nn.Conv2d(
                    in_channels=channel[0],
                    out_channels=channel[0],
                    kernel_size=3,
                    padding=1,
                ),
                nn.Conv2d(
                    in_channels=channel[0],
                    out_channels=channel[1],
                    kernel_size=1,
                    padding=0,
                ),
            )
        return conv_block

    def att_layer(self, channel):
        """Attention gate: two 1x1 conv+BN stages ending in a sigmoid mask."""
        att_block = nn.Sequential(
            nn.Conv2d(
                in_channels=channel[0],
                out_channels=channel[1],
                kernel_size=1,
                padding=0,
            ),
            nn.BatchNorm2d(channel[1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=channel[1],
                out_channels=channel[2],
                kernel_size=1,
                padding=0,
            ),
            nn.BatchNorm2d(channel[2]),
            nn.Sigmoid(),
        )
        return att_block

    def forward(self, x):
        """Return ([t1_pred, t2_pred, t3_pred], per-task representations)."""
        g_encoder, g_decoder, g_maxpool, g_upsampl, indices = (
            [0] * 5 for _ in range(5)
        )
        for i in range(5):
            g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))

        # define attention list for tasks
        atten_encoder, atten_decoder = ([0] * 3 for _ in range(2))
        for i in range(3):
            atten_encoder[i], atten_decoder[i] = ([0] * 5 for _ in range(2))
        for i in range(3):
            for j in range(5):
                atten_encoder[i][j], atten_decoder[i][j] = ([0] * 3 for _ in range(2))

        # define global shared network
        for i in range(5):
            if i == 0:
                g_encoder[i][0] = self.encoder_block[i](x)
                g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
                g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
            else:
                g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
                g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
                g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])

        # Decoder mirrors the encoder; unpooling reuses the pooling indices.
        for i in range(5):
            if i == 0:
                g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
                g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
                g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
            else:
                g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
                g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
                g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])

        # define task dependent attention module
        # i indexes the task, j the encoder/decoder stage.
        for i in range(3):
            for j in range(5):
                if j == 0:
                    atten_encoder[i][j][0] = self.encoder_att[i][j](g_encoder[j][0])
                    atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
                    atten_encoder[i][j][2] = self.encoder_block_att[j](
                        atten_encoder[i][j][1]
                    )
                    atten_encoder[i][j][2] = F.max_pool2d(
                        atten_encoder[i][j][2], kernel_size=2, stride=2
                    )
                else:
                    atten_encoder[i][j][0] = self.encoder_att[i][j](
                        torch.cat((g_encoder[j][0], atten_encoder[i][j - 1][2]), dim=1)
                    )
                    atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
                    atten_encoder[i][j][2] = self.encoder_block_att[j](
                        atten_encoder[i][j][1]
                    )
                    atten_encoder[i][j][2] = F.max_pool2d(
                        atten_encoder[i][j][2], kernel_size=2, stride=2
                    )

            for j in range(5):
                if j == 0:
                    atten_decoder[i][j][0] = F.interpolate(
                        atten_encoder[i][-1][-1],
                        scale_factor=2,
                        mode="bilinear",
                        align_corners=True,
                    )
                    atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](
                        atten_decoder[i][j][0]
                    )
                    atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](
                        torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1)
                    )
                    atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]
                else:
                    atten_decoder[i][j][0] = F.interpolate(
                        atten_decoder[i][j - 1][2],
                        scale_factor=2,
                        mode="bilinear",
                        align_corners=True,
                    )
                    atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](
                        atten_decoder[i][j][0]
                    )
                    atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](
                        torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1)
                    )
                    atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]

        # define task prediction layers
        t1_pred = F.log_softmax(self.pred_task1(atten_decoder[0][-1][-1]), dim=1)
        t2_pred = self.pred_task2(atten_decoder[1][-1][-1])
        t3_pred = self.pred_task3(atten_decoder[2][-1][-1])
        # task-3 output is L2-normalized along the channel axis
        t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)

        return (
            [t1_pred, t2_pred, t3_pred],
            (
                atten_decoder[0][-1][-1],
                atten_decoder[1][-1][-1],
                atten_decoder[2][-1][-1],
            ),
        )


class SegNetMtan(nn.Module):
    """Wrapper around `_SegNet` exposing the WeightMethod parameter groups."""

    def __init__(self):
        super().__init__()
        self.segnet = _SegNet()

    def shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
        return (p for n, p in self.segnet.named_parameters() if "pred" not in n)

    def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]:
        return (p for n, p in self.segnet.named_parameters() if "pred" in n)

    def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
        """Parameters of the last shared layer.
        Returns
        -------
        """
        return []  # NOTE(review): the MTAN variant exposes no last shared layer

    def forward(self, x, return_representation=False):
        """Run the network; optionally also return the shared representation."""
        if return_representation:
            return self.segnet(x)
        else:
            pred, rep = self.segnet(x)
            return pred


class SegNetSplit(nn.Module):
    """Plain SegNet encoder/decoder with three task-specific heads (no attention)."""

    def __init__(self, model_type="standard"):
        super(SegNetSplit, self).__init__()
        # initialise network parameters
        assert model_type in ["standard", "wide", "deep"]
        self.model_type = model_type
        # "wide" widens only the last encoder stage
        if self.model_type == "wide":
            filter = [64, 128, 256, 512, 1024]
        else:
            filter = [64, 128, 256, 512, 512]

        self.class_nb = 13

        # define encoder decoder layers
        self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
        self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
        for i in range(4):
            self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
            self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))

        # define convolution layer
        self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
        self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
        for i in range(4):
            if i == 0:
                self.conv_block_enc.append(
                    self.conv_layer([filter[i + 1], filter[i + 1]])
                )
                self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
            else:
                self.conv_block_enc.append(
                    nn.Sequential(
                        self.conv_layer([filter[i + 1], filter[i + 1]]),
                        self.conv_layer([filter[i + 1], filter[i + 1]]),
                    )
                )
                self.conv_block_dec.append(
                    nn.Sequential(
                        self.conv_layer([filter[i], filter[i]]),
                        self.conv_layer([filter[i], filter[i]]),
                    )
                )

        # define task specific layers
        # head channels: task1 -> class_nb, task2 -> 1, task3 -> 3
        self.pred_task1 = nn.Sequential(
            nn.Conv2d(
                in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
            ),
            nn.Conv2d(
                in_channels=filter[0],
                out_channels=self.class_nb,
                kernel_size=1,
                padding=0,
            ),
        )
        self.pred_task2 = nn.Sequential(
            nn.Conv2d(
                in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
            ),
            nn.Conv2d(in_channels=filter[0], out_channels=1, kernel_size=1, padding=0),
        )
        self.pred_task3 = nn.Sequential(
            nn.Conv2d(
                in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
            ),
            nn.Conv2d(in_channels=filter[0], out_channels=3, kernel_size=1, padding=0),
        )

        # define pooling and unpooling functions
        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)

        # Xavier init for conv/linear weights; unit scale, zero shift for BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    # define convolutional block
    def conv_layer(self, channel):
        """conv->BN->ReLU block; doubled when model_type is 'deep'."""
        if self.model_type == "deep":
            conv_block = nn.Sequential(
                nn.Conv2d(
                    in_channels=channel[0],
                    out_channels=channel[1],
                    kernel_size=3,
                    padding=1,
                ),
                nn.BatchNorm2d(num_features=channel[1]),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    in_channels=channel[1],
                    out_channels=channel[1],
                    kernel_size=3,
                    padding=1,
                ),
                nn.BatchNorm2d(num_features=channel[1]),
                nn.ReLU(inplace=True),
            )
        else:
            conv_block = nn.Sequential(
                nn.Conv2d(
                    in_channels=channel[0],
                    out_channels=channel[1],
                    kernel_size=3,
                    padding=1,
                ),
                nn.BatchNorm2d(num_features=channel[1]),
                nn.ReLU(inplace=True),
            )
        return conv_block

    def forward(self, x):
        """Return ([t1_pred, t2_pred, t3_pred], shared representation)."""
        g_encoder, g_decoder, g_maxpool, g_upsampl, indices = (
            [0] * 5 for _ in range(5)
        )
        for i in range(5):
            g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))

        # global shared encoder-decoder network
        for i in range(5):
            if i == 0:
                g_encoder[i][0] = self.encoder_block[i](x)
                g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
                g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
            else:
                g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
                g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
                g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])

        for i in range(5):
            if i == 0:
                g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
                g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
                g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
            else:
                g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
                g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
                g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])

        # define task prediction layers
        # (i == 4 here, so predictions read the final decoder output)
        t1_pred = F.log_softmax(self.pred_task1(g_decoder[i][1]), dim=1)
        t2_pred = self.pred_task2(g_decoder[i][1])
        t3_pred = self.pred_task3(g_decoder[i][1])
        # task-3 output is L2-normalized along the channel axis
        t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)

        return [t1_pred, t2_pred, t3_pred], g_decoder[i][
            1
        ]  # NOTE: last element is representation


class SegNet(nn.Module):
    """Wrapper around `SegNetSplit` exposing the WeightMethod parameter groups."""

    def __init__(self):
        super().__init__()
        self.segnet = SegNetSplit()

    def shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
        return (p for n, p in self.segnet.named_parameters() if "pred" not in n)

    def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]:
        return (p for n, p in self.segnet.named_parameters() if "pred" in n)

    def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
        """Parameters of the last shared layer.
        Returns
        -------
        """
        return self.segnet.conv_block_dec[-5].parameters()

    def forward(self, x, return_representation=False):
        """Run the network; optionally also return the shared representation."""
        if return_representation:
            return self.segnet(x)
        else:
            pred, rep = self.segnet(x)
            return pred
--------------------------------------------------------------------------------
/methods/weight_methods.py:
--------------------------------------------------------------------------------
import copy
import random
from abc import abstractmethod
from typing import Dict, List, Tuple, Union

import cvxpy as cp
import numpy as np
import torch
import torch.nn.functional as F
from scipy.optimize import minimize

from methods.min_norm_solvers import MinNormSolver, gradient_normalizers
from methods.cluster_methods import *


class WeightMethod:
    """Base class for multi-task loss-weighting strategies.

    Subclasses implement `get_weighted_loss`; `backward` then computes the
    weighted loss and backpropagates it.
    """

    # NOTE(review): the class does not inherit abc.ABC, so @abstractmethod
    # below is not actually enforced at instantiation time.

    def __init__(self, n_tasks: int, device: torch.device):
        super().__init__()
        self.n_tasks = n_tasks
        self.device = device

    @abstractmethod
    def get_weighted_loss(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ],
        last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
        representation: Union[torch.nn.parameter.Parameter, torch.Tensor],
        **kwargs,
    ):
        pass

    def backward(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        last_shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]:
        """Compute the weighted loss and call `.backward()` on it.

        Parameters
        ----------
        losses : per-task losses
        shared_parameters : parameters shared across tasks
        task_specific_parameters : per-task parameters
        last_shared_parameters : parameters of last shared layer/block
        representation : shared representation
        kwargs : forwarded to `get_weighted_loss`

        Returns
        -------
        Loss, extra outputs
        """
        loss, extra_outputs = self.get_weighted_loss(
            losses=losses,
            shared_parameters=shared_parameters,
            task_specific_parameters=task_specific_parameters,
            last_shared_parameters=last_shared_parameters,
            representation=representation,
            **kwargs,
        )
        loss.backward()
        return loss, extra_outputs

    def __call__(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        **kwargs,
    ):
        # Calling a method instance is a shorthand for `backward`.
        return self.backward(
            losses=losses,
            shared_parameters=shared_parameters,
            task_specific_parameters=task_specific_parameters,
            **kwargs,
        )

    def parameters(self) -> List[torch.Tensor]:
        """return learnable parameters"""
        return []


class NashMTL(WeightMethod):
    """Nash-MTL: task weights from a Nash bargaining solution over gradients."""

    def __init__(
        self,
        n_tasks: int,
        device: torch.device,
        max_norm: float = 1.0,
        update_weights_every: int = 1,
        optim_niter=20,
    ):
        super(NashMTL, self).__init__(
            n_tasks=n_tasks,
            device=device,
        )

        self.optim_niter = optim_niter
        self.update_weights_every = update_weights_every
        self.max_norm = max_norm

        self.prvs_alpha_param = None
        self.normalization_factor = np.ones((1,))
        # NOTE(review): the doubled assignment below is redundant (kept verbatim).
        self.init_gtg = self.init_gtg =
np.eye(self.n_tasks) 119 | self.step = 0.0 120 | self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32) 121 | 122 | def _stop_criteria(self, gtg, alpha_t): 123 | return ( 124 | (self.alpha_param.value is None) 125 | or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3) 126 | or ( 127 | np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) 128 | < 1e-6 129 | ) 130 | ) 131 | 132 | def solve_optimization(self, gtg: np.array): 133 | self.G_param.value = gtg 134 | self.normalization_factor_param.value = self.normalization_factor 135 | 136 | alpha_t = self.prvs_alpha 137 | for _ in range(self.optim_niter): 138 | self.alpha_param.value = alpha_t 139 | self.prvs_alpha_param.value = alpha_t 140 | 141 | try: 142 | self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100) 143 | except: 144 | self.alpha_param.value = self.prvs_alpha_param.value 145 | 146 | if self._stop_criteria(gtg, alpha_t): 147 | break 148 | 149 | alpha_t = self.alpha_param.value 150 | 151 | if alpha_t is not None: 152 | self.prvs_alpha = alpha_t 153 | 154 | return self.prvs_alpha 155 | 156 | def _calc_phi_alpha_linearization(self): 157 | G_prvs_alpha = self.G_param @ self.prvs_alpha_param 158 | prvs_phi_tag = 1 / self.prvs_alpha_param + (1 / G_prvs_alpha) @ self.G_param 159 | phi_alpha = prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param) 160 | return phi_alpha 161 | 162 | def _init_optim_problem(self): 163 | self.alpha_param = cp.Variable(shape=(self.n_tasks,), nonneg=True) 164 | self.prvs_alpha_param = cp.Parameter( 165 | shape=(self.n_tasks,), value=self.prvs_alpha 166 | ) 167 | self.G_param = cp.Parameter( 168 | shape=(self.n_tasks, self.n_tasks), value=self.init_gtg 169 | ) 170 | self.normalization_factor_param = cp.Parameter( 171 | shape=(1,), value=np.array([1.0]) 172 | ) 173 | 174 | self.phi_alpha = self._calc_phi_alpha_linearization() 175 | 176 | G_alpha = self.G_param @ self.alpha_param 177 | constraint = [] 178 | for i in range(self.n_tasks): 179 | 
            # NOTE(review): this is the tail of NashMTL._init_optim_problem; the class
            # header and the enclosing `for i in range(self.n_tasks):` loop are above
            # this chunk and not visible here — TODO confirm loop variable is `i`.
            # One log-constraint per task: log(alpha_i * c) + log((G alpha)_i) >= 0.
            constraint.append(
                -cp.log(self.alpha_param[i] * self.normalization_factor_param)
                - cp.log(G_alpha[i])
                <= 0
            )
        # Objective of the Nash bargaining subproblem solved with cvxpy.
        obj = cp.Minimize(
            cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param
        )
        self.prob = cp.Problem(obj, constraint)

    def get_weighted_loss(
        self,
        losses,
        shared_parameters,
        **kwargs,
    ):
        """Compute the Nash-MTL weighted loss.

        Every ``update_weights_every`` steps the per-task gradients w.r.t. the
        shared parameters are recomputed, their Gram matrix is normalized, and
        the bargaining weights ``alpha`` are re-solved via ``solve_optimization``;
        in between, the previous ``alpha`` is reused.

        Parameters
        ----------
        losses : sequence of per-task scalar losses (autograd graph attached)
        shared_parameters : shared parameters the task gradients are taken w.r.t.
        kwargs : unused; kept so all weight methods share one call signature

        Returns
        -------
        (weighted_loss, extra_outputs) where ``extra_outputs["weights"]`` is alpha.
        """

        extra_outputs = dict()
        if self.step == 0:
            # Lazily build the cvxpy problem on the first call.
            self._init_optim_problem()

        if (self.step % self.update_weights_every) == 0:
            self.step += 1

            grads = {}
            for i, loss in enumerate(losses):
                # retain_graph: the same graph is differentiated once per task.
                g = list(
                    torch.autograd.grad(
                        loss,
                        shared_parameters,
                        retain_graph=True,
                    )
                )
                grad = torch.cat([torch.flatten(grad) for grad in g])
                grads[i] = grad  # flattened shared-gradient (e.g. 44117184 entries)

            G = torch.stack(tuple(v for v in grads.values()))  # (n_tasks, dim)
            GTG = torch.mm(G, G.t())  # Gram matrix, (n_tasks, n_tasks)

            # Scale the Gram matrix so the cvxpy problem is numerically well
            # conditioned; the factor is kept to undo the scaling downstream.
            self.normalization_factor = (torch.norm(GTG).detach().cpu().numpy().reshape((1,)))
            GTG = GTG / self.normalization_factor.item()
            alpha = self.solve_optimization(GTG.cpu().detach().numpy())
            alpha = torch.from_numpy(alpha)

        else:
            # Off-cycle step: reuse the cached weights.
            self.step += 1
            alpha = self.prvs_alpha
        # ----------------- (debugging scratch kept from the original authors)
        # updated_gradient = alpha.unsqueeze(1).float().cuda() * G
        # updated_overall = updated_gradient.sum(0)
        # # print(updated_overall.norm() - np.sqrt(3))
        # projection = updated_overall @ torch.transpose(updated_gradient, 0, 1)
        # # print(projection)
        # # cosine = projection / (updated_overall.norm() * updated_gradient.norm(dim=1))
        # norm_gradient = updated_gradient / updated_gradient.norm(dim=1).unsqueeze(1)
        #
        # norm_gradient @ torch.transpose(norm_gradient, 0, 1)
        # norm_G = G / G.norm(dim=1).unsqueeze(1)
        # norm_G @ torch.transpose(norm_G, 0, 1)
        # -----------------
        weighted_loss = sum([losses[i] * alpha[i] for i in range(len(alpha))])
        extra_outputs["weights"] = alpha
        return weighted_loss, extra_outputs

    def backward(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        last_shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
        """Backpropagate the Nash-MTL weighted loss, clipping shared-grad norm."""
        loss, extra_outputs = self.get_weighted_loss(
            losses=losses,
            shared_parameters=shared_parameters,
            **kwargs,
        )
        loss.backward()

        # make sure the solution for shared params has norm <= self.eps
        if self.max_norm > 0:
            torch.nn.utils.clip_grad_norm_(shared_parameters, self.max_norm)

        return loss, extra_outputs


class LinearScalarization(WeightMethod):
    """Linear scalarization baseline: L = sum_j w_j * l_j, where l_j is the loss
    for task j and w_j is a fixed (default: uniform) task weight."""

    def __init__(
        self,
        n_tasks: int,
        device: torch.device,
        task_weights: Union[List[float], torch.Tensor] = None,
    ):
        super().__init__(n_tasks, device=device)
        # Default to uniform weights; accept lists or tensors.
        if task_weights is None:
            task_weights = torch.ones((n_tasks,))
        if not isinstance(task_weights, torch.Tensor):
            task_weights = torch.tensor(task_weights)
        assert len(task_weights) == n_tasks
        self.task_weights = task_weights.to(device)

    def get_weighted_loss(self, losses, **kwargs):
        """Return the fixed-weight sum of task losses."""
        loss = torch.sum(losses * self.task_weights)
        return loss, dict(weights=self.task_weights)


class ScaleInvariantLinearScalarization(WeightMethod):
    """Scale-invariant linear scalarization: L = sum_j w_j * log(l_j), so the
    objective is invariant to per-task rescaling of the losses."""

    def __init__(
        self,
        n_tasks: int,
        device: torch.device,
        task_weights: Union[List[float], torch.Tensor] = None,
    ):
        super().__init__(n_tasks, device=device)
        if task_weights is None:
            task_weights = torch.ones((n_tasks,))
        if not isinstance(task_weights, torch.Tensor):
            task_weights = torch.tensor(task_weights)
        assert len(task_weights) == n_tasks
        self.task_weights = task_weights.to(device)

    def get_weighted_loss(self, losses, **kwargs):
        """Return the fixed-weight sum of log task losses (assumes losses > 0)."""
        loss = torch.sum(torch.log(losses) * self.task_weights)
        return loss, dict(weights=self.task_weights)


class MGDA(WeightMethod):
    """Based on the official implementation of: Multi-Task Learning as Multi-Objective Optimization
    Ozan Sener, Vladlen Koltun
    Neural Information Processing Systems (NeurIPS) 2018
    https://github.com/intel-isl/MultiObjectiveOptimization

    """

    def __init__(
        self, n_tasks, device: torch.device, params="shared", normalization="none"
    ):
        super().__init__(n_tasks, device=device)
        self.solver = MinNormSolver()
        # Which parameter set the per-task gradients are taken against.
        assert params in ["shared", "last", "rep"]
        self.params = params
        # Gradient normalization scheme fed to `gradient_normalizers`.
        assert normalization in ["norm", "loss", "loss+", "none"]
        self.normalization = normalization

    @staticmethod
    def _flattening(grad):
        # Concatenate a list of gradient tensors into one flat vector.
        # NOTE(review): not referenced by get_weighted_loss below; possibly used
        # elsewhere in the file — confirm before removing.
        return torch.cat(
            tuple(
                g.reshape(
                    -1,
                )
                for i, g in enumerate(grad)
            ),
            dim=0,
        )

    def get_weighted_loss(
        self,
        losses,
        shared_parameters=None,
        last_shared_parameters=None,
        representation=None,
        **kwargs,
    ):
        """Solve the min-norm problem over task gradients and weight the losses.

        Parameters
        ----------
        losses : sequence of per-task scalar losses
        shared_parameters : used when ``self.params == "shared"``
        last_shared_parameters : used when ``self.params == "last"``
        representation : used when ``self.params == "rep"``
        kwargs : unused

        Returns
        -------
        (weighted_loss, dict with the solver weights rescaled to sum to n_tasks)
        """
        # Our code
        grads = {}
        params = dict(
            rep=representation, shared=shared_parameters, last=last_shared_parameters
        )[self.params]
        for i, loss in enumerate(losses):
            g = list(
                torch.autograd.grad(
                    loss,
                    params,
                    retain_graph=True,
                )
            )
            # Normalize all gradients, this is optional and not included in the paper.

            grads[i] = [torch.flatten(grad) for grad in g]

        gn = gradient_normalizers(grads, losses, self.normalization)
        for t in range(self.n_tasks):
            for gr_i in range(len(grads[t])):
                grads[t][gr_i] = grads[t][gr_i] / gn[t]

        sol, min_norm = self.solver.find_min_norm_element(
            [grads[t] for t in range(len(grads))]
        )
        sol = sol * self.n_tasks  # make sure it sums to self.n_tasks
        weighted_loss = sum([losses[i] * sol[i] for i in range(len(sol))])

        return weighted_loss, dict(weights=torch.from_numpy(sol.astype(np.float32)))


class STL(WeightMethod):
    """Single task learning: optimize one designated task and ignore the rest."""

    def __init__(self, n_tasks, device: torch.device, main_task):
        super().__init__(n_tasks, device=device)
        self.main_task = main_task
        # One-hot weight vector selecting the main task.
        self.weights = torch.zeros(n_tasks, device=device)
        self.weights[main_task] = 1.0

    def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
        """Return only the main task's loss."""
        assert len(losses) == self.n_tasks
        loss = losses[self.main_task]

        return loss, dict(weights=self.weights)


class Uncertainty(WeightMethod):
    """Implementation of `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics`
    Source: https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynb
    """

    def __init__(self, n_tasks, device: torch.device):
        super().__init__(n_tasks, device=device)
        # Learnable log-variance per task; exposed through parameters() so the
        # training loop can optimize it jointly with the model.
        self.logsigma = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True)

    def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
        """Weight each loss by exp(-logsigma) with a +logsigma regularizer."""
        loss = sum(
            [
                0.5 * (torch.exp(-logs) * loss + logs)
                for loss, logs in zip(losses, self.logsigma)
            ]
        )

        return loss, dict(
            weights=torch.exp(-self.logsigma)
        )  # NOTE: not exactly task weights

    def parameters(self) -> List[torch.Tensor]:
        return [self.logsigma]


class PCGrad(WeightMethod):
    """Modification of: https://github.com/WeiChengTseng/Pytorch-PCGrad/blob/master/pcgrad.py

    @misc{Pytorch-PCGrad,
        author = {Wei-Cheng Tseng},
        title = {WeiChengTseng/Pytorch-PCGrad},
        url = {https://github.com/WeiChengTseng/Pytorch-PCGrad.git},
        year = {2020}
    }

    """

    def __init__(self, n_tasks: int, device: torch.device, reduction="sum"):
        super().__init__(n_tasks, device=device)
        assert reduction in ["mean", "sum"]
        self.reduction = reduction

    def get_weighted_loss(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        **kwargs,
    ):
        # PCGrad manipulates gradients directly (see backward), so a single
        # scalar weighted loss is intentionally not defined.
        raise NotImplementedError

    def _set_pc_grads(self, losses, shared_parameters, task_specific_parameters=None):
        """Compute projected-conflicting gradients and write them to .grad."""
        # shared part
        shared_grads = []
        for l in losses:
            shared_grads.append(
                torch.autograd.grad(l, shared_parameters, retain_graph=True)
            )

        if isinstance(shared_parameters, torch.Tensor):
            shared_parameters = [shared_parameters]
        non_conflict_shared_grads = self._project_conflicting(shared_grads)
        for p, g in zip(shared_parameters, non_conflict_shared_grads):
            p.grad = g

        # task specific part
        if task_specific_parameters is not None:
            # Task-specific params only see their own task's loss; summing is
            # therefore equivalent and needs a single autograd pass.
            task_specific_grads = torch.autograd.grad(
                losses.sum(), task_specific_parameters
            )
            if isinstance(task_specific_parameters, torch.Tensor):
                task_specific_parameters = [task_specific_parameters]
            for p, g in zip(task_specific_parameters, task_specific_grads):
                p.grad = g

    def _project_conflicting(self, grads: List[Tuple[torch.Tensor]]):
        """Project each task gradient away from gradients it conflicts with.

        ``grads`` is a list (per task) of per-parameter gradient tuples. The
        deep copy is modified in place; the originals in ``grads`` serve as the
        (shuffled) projection targets, as in the reference implementation.
        """
        pc_grad = copy.deepcopy(grads)
        for g_i in pc_grad:
            # Random task order per the PCGrad algorithm; note this shuffles
            # the *original* grads list in place.
            random.shuffle(grads)
            for g_j in grads:
                # Inner product across all parameter tensors of the two tasks.
                g_i_g_j = sum(
                    [
                        torch.dot(torch.flatten(grad_i), torch.flatten(grad_j))
                        for grad_i, grad_j in zip(g_i, g_j)
                    ]
                )
                if g_i_g_j < 0:
                    # Conflict: remove the component of g_i along g_j.
                    g_j_norm_square = (
                        torch.norm(torch.cat([torch.flatten(g) for g in g_j])) ** 2
                    )
                    for grad_i, grad_j in zip(g_i, g_j):
                        grad_i -= g_i_g_j * grad_j / g_j_norm_square

        # Merge: per-parameter sum (or mean) of the projected task gradients.
        merged_grad = [sum(g) for g in zip(*pc_grad)]
        if self.reduction == "mean":
            merged_grad = [g / self.n_tasks for g in merged_grad]

        return merged_grad

    def backward(
        self,
        losses: torch.Tensor,
        parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        **kwargs,
    ):
        self._set_pc_grads(losses, shared_parameters, task_specific_parameters)
        return None, {}  # NOTE: to align with all other weight methods


class CAGrad(WeightMethod):
    """Conflict-Averse Gradient descent; ``c`` controls the conflict radius."""

    def __init__(self, n_tasks, device: torch.device, c=0.4):
        super().__init__(n_tasks, device=device)
        self.c = c

    def get_weighted_loss(
        self,
        losses,
        shared_parameters,
        **kwargs,
    ):
        """Run the CAGrad update: collect per-task gradients, solve for the
        conflict-averse direction, and overwrite ``.grad`` on shared params.

        Parameters
        ----------
        losses : per-task scalar losses
        shared_parameters : shared parameters
        kwargs : unused

        Returns
        -------
        None — gradients are written in place; see ``backward``.
        """
        # NOTE: we allow only shared params for now. Need to see paper for other options.
        grad_dims = []
        for param in shared_parameters:
            grad_dims.append(param.data.numel())
        # One flattened-gradient column per task.
        grads = torch.Tensor(sum(grad_dims), self.n_tasks).to(self.device)

        for i in range(self.n_tasks):
            if i < (self.n_tasks - 1):
                losses[i].backward(retain_graph=True)
            else:
                # Last task: graph can be freed.
                losses[i].backward()
            self.grad2vec(shared_parameters, grads, grad_dims, i)
            # multi_task_model.zero_grad_shared_modules()
            for p in shared_parameters:
                p.grad = None

        g = self.cagrad(grads, alpha=self.c, rescale=1)
        self.overwrite_grad(shared_parameters, g, grad_dims)

    def cagrad(self, grads, alpha=0.5, rescale=1):
        """Solve the CAGrad dual on CPU with scipy and return the update vector.

        ``rescale`` selects the final scaling of the combined gradient.
        """
        GG = grads.t().mm(grads).cpu()  # [num_tasks, num_tasks]
        g0_norm = (GG.mean() + 1e-8).sqrt()  # norm of the average gradient

        x_start = np.ones(self.n_tasks) / self.n_tasks
        bnds = tuple((0, 1) for x in x_start)
        cons = {"type": "eq", "fun": lambda x: 1 - sum(x)}  # simplex constraint
        A = GG.numpy()
        b = x_start.copy()
        c = (alpha * g0_norm + 1e-8).item()

        def objfn(x):
            # x^T A b + c * sqrt(x^T A x): linear term plus conflict penalty.
            return (
                x.reshape(1, self.n_tasks).dot(A).dot(b.reshape(self.n_tasks, 1))
                + c
                * np.sqrt(
                    x.reshape(1, self.n_tasks).dot(A).dot(x.reshape(self.n_tasks, 1))
                    + 1e-8
                )
            ).sum()

        res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
        w_cpu = res.x
        ww = torch.Tensor(w_cpu).to(grads.device)
        gw = (grads * ww.view(1, -1)).sum(1)
        gw_norm = gw.norm()
        lmbda = c / (gw_norm + 1e-8)
        g = grads.mean(1) + lmbda * gw
        if rescale == 0:
            return g
        elif rescale == 1:
            return g / (1 + alpha ** 2)
        else:
            return g / (1 + alpha)

    @staticmethod
    def grad2vec(shared_params, grads, grad_dims, task):
        # store the gradients
        grads[:, task].fill_(0.0)
        cnt = 0
        # for mm in m.shared_modules():
        #     for p in mm.parameters():

        for param in shared_params:
            grad = param.grad
            if grad is not None:
                grad_cur = grad.data.detach().clone()
                # Slice of the flat vector owned by this parameter.
                beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
                en = sum(grad_dims[: cnt + 1])
                grads[beg:en, task].copy_(grad_cur.data.view(-1))
            cnt += 1

    def overwrite_grad(self, shared_parameters, newgrad, grad_dims):
        """Unflatten ``newgrad`` and assign it to each shared parameter's .grad."""
        newgrad = newgrad * self.n_tasks  # to match the sum loss
        cnt = 0

        # for mm in m.shared_modules():
        #     for param in mm.parameters():
        for param in shared_parameters:
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            en = sum(grad_dims[: cnt + 1])
            this_grad = newgrad[beg:en].contiguous().view(param.data.size())
            param.grad = this_grad.data.clone()
            cnt += 1

    def backward(
        self,
        losses: torch.Tensor,
        parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        **kwargs,
    ):
        self.get_weighted_loss(losses, shared_parameters)
        return None, {}  # NOTE: to align with all other weight methods


class RLW(WeightMethod):
    """Random loss weighting: https://arxiv.org/pdf/2111.10603.pdf"""

    def __init__(self, n_tasks, device: torch.device):
        super().__init__(n_tasks, device=device)

    def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
        """Draw fresh softmax(Gaussian) weights each call and weight the losses."""
        assert len(losses) == self.n_tasks
        weight = (F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device)
        loss = torch.sum(losses * weight)

        return loss, dict(weights=weight)


class IMTLG(WeightMethod):
    """TOWARDS IMPARTIAL MULTI-TASK LEARNING: https://openreview.net/pdf?id=IMPnRXEWpvr"""

    def __init__(self, n_tasks, device: torch.device):
        super().__init__(n_tasks, device=device)

    def get_weighted_loss(
        self,
        losses,
        shared_parameters,
        **kwargs,
    ):
        """Solve the IMTL-G closed form for task weights from shared gradients."""
        grads = {}
        norm_grads = {}

        for i, loss in enumerate(losses):
            g = list(
                torch.autograd.grad(
                    loss,
                    shared_parameters,
                    retain_graph=True,
                )
            )
            grad = torch.cat([torch.flatten(grad) for grad in g])
            norm_term = torch.norm(grad)

            grads[i] = grad
            norm_grads[i] = grad / norm_term  # unit-norm gradient

        G = torch.stack(tuple(v for v in grads.values()))
        # Differences of task gradients against task 0 (paper's D and U).
        D = (
            G[
                0,
            ]
            - G[
                1:,
            ]
        )

        U = torch.stack(tuple(v for v in norm_grads.values()))
        U = (
            U[
                0,
            ]
            - U[
                1:,
            ]
        )
        first_element = torch.matmul(
            G[
                0,
            ],
            U.t(),
        )
        # NOTE(review): bare `except:` — should be `except RuntimeError:` (or the
        # specific linalg error) so unrelated failures are not silently masked.
        try:
            second_element = torch.inverse(torch.matmul(D, U.t()))
        except:
            # workaround for cases where matrix is singular
            second_element = torch.inverse(
                torch.eye(self.n_tasks - 1, device=self.device) * 1e-8
                + torch.matmul(D, U.t())
            )

        alpha_ = torch.matmul(first_element, second_element)
        # Weight of task 0 is fixed so that all weights sum to 1.
        alpha = torch.cat(
            (torch.tensor(1 - alpha_.sum(), device=self.device).unsqueeze(-1), alpha_)
        )

        loss = torch.sum(losses * alpha)

        return loss, dict(weights=alpha)


class DynamicWeightAverage(WeightMethod):
    """Dynamic Weight Average from `End-to-End Multi-Task Learning with Attention`.
    Modification of: https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_segnet_split.py#L242
    """

    def __init__(
        self, n_tasks, device: torch.device, iteration_window: int = 25, temp=2.0
    ):
        """
        Parameters
        ----------
        n_tasks : number of tasks
        iteration_window : 'iteration' loss is averaged over the last 'iteration_window' losses
        temp : softmax temperature for the weight update
        """
        super().__init__(n_tasks, device=device)
        self.iteration_window = iteration_window
        self.temp = temp
        self.running_iterations = 0
        # Rolling cost buffer: two windows (older half / newer half) per task.
        self.costs = np.ones((iteration_window * 2, n_tasks), dtype=np.float32)
        self.weights = np.ones(n_tasks, dtype=np.float32)

    def get_weighted_loss(self, losses, **kwargs):
        """Weight losses by the relative descent rate over two cost windows."""

        cost = losses.detach().cpu().numpy()

        # update costs - fifo
        self.costs[:-1, :] = self.costs[1:, :]
        self.costs[-1, :] = cost

        if self.running_iterations > self.iteration_window:
            # Ratio of recent-window mean to older-window mean per task.
            ws = self.costs[self.iteration_window :, :].mean(0) / self.costs[
                : self.iteration_window, :
            ].mean(0)
            self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (
                np.exp(ws / self.temp)
            ).sum()

        task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(
            losses.device
        )
        loss = (task_weights * losses).mean()

        self.running_iterations += 1

        return loss, dict(weights=task_weights)


class FAMO(WeightMethod):
    """Fast Adaptive Multitask Optimization: learnable logits ``w`` produce
    softmax task weights, updated from observed loss improvement in ``update``."""

    def __init__(
        self,
        n_tasks: int,
        device: torch.device,
        gamma: float = 1e-5,
        w_lr: float = 0.025,
        task_weights: Union[List[float], torch.Tensor] = None,
        max_norm: float = 1.0,
    ):
        super().__init__(n_tasks, device=device)
        # Per-task lower bounds used to keep log arguments positive.
        self.min_losses = torch.zeros(n_tasks).to(device)
        self.w = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True)
        self.w_opt = torch.optim.Adam([self.w], lr=w_lr, weight_decay=gamma)
        self.max_norm = max_norm

    def set_min_losses(self, losses):
        self.min_losses = losses

    def get_weighted_loss(self, losses, **kwargs):
        """Return the FAMO surrogate loss; caches losses for the next update()."""
        self.prev_loss = losses
        z = F.softmax(self.w, -1)
        D = losses - self.min_losses + 1e-8  # strictly positive gaps
        c = (z / D).sum().detach()
        loss = (D.log() * z / c).sum()
        # return loss, {"weights": z, "logits": self.w.detach().clone()}
        return loss, dict(weights=torch.cat([z]))

    def update(self, curr_loss):
        """Adapt the logits from the log-loss improvement since the last step."""
        delta = (self.prev_loss - self.min_losses + 1e-8).log() - \
                (curr_loss - self.min_losses + 1e-8).log()
        with torch.enable_grad():
            # VJP of softmax(w) with the improvement as the output gradient.
            d = torch.autograd.grad(F.softmax(self.w, -1),
                                    self.w,
                                    grad_outputs=delta.detach())[0]
        self.w_opt.zero_grad()
        self.w.grad = d
        self.w_opt.step()


class GradDrop(WeightMethod):
    """Gradient sign-dropout: randomly masks gradient entries whose sign
    disagrees with the per-entry sign-consistency probability P."""

    def __init__(self, n_tasks, device: torch.device, max_norm=1.0):
        super().__init__(n_tasks, device=device)
        self.max_norm = max_norm

    def get_weighted_loss(
        self,
        losses,
        shared_parameters,
        **kwargs,
    ):
        """Collect per-task gradients, apply the GradDrop mask, and overwrite
        ``.grad`` on the shared parameters.

        Parameters
        ----------
        losses : per-task scalar losses
        shared_parameters : shared parameters
        kwargs : unused

        Returns
        -------
        None — gradients are written in place; see ``backward``.
        """
        # NOTE: we allow only shared params for now. Need to see paper for other options.
        grad_dims = []
        for param in shared_parameters:
            grad_dims.append(param.data.numel())
        grads = torch.Tensor(sum(grad_dims), self.n_tasks).to(self.device)

        for i in range(self.n_tasks):
            # NOTE(review): `i < self.n_tasks` is always true, so the else branch
            # is unreachable and every backward retains the graph (extra memory).
            # CAGrad uses `i < (self.n_tasks - 1)` — presumably intended here too.
            if i < self.n_tasks:
                losses[i].backward(retain_graph=True)
            else:
                losses[i].backward()
            self.grad2vec(shared_parameters, grads, grad_dims, i)
            # multi_task_model.zero_grad_shared_modules()
            for p in shared_parameters:
                p.grad = None

        # P: per-entry probability that the summed gradient sign is positive.
        P = 0.5 * (1. + grads.sum(1) / (grads.abs().sum(1) + 1e-8))
        U = torch.rand_like(grads[:, 0])
        # Keep positive entries with prob P and negative entries with prob 1-P.
        M = P.gt(U).view(-1, 1) * grads.gt(0) + P.lt(U).view(-1, 1) * grads.lt(0)
        g = (grads * M.float()).mean(1)
        self.overwrite_grad(shared_parameters, g, grad_dims)

    @staticmethod
    def grad2vec(shared_params, grads, grad_dims, task):
        # store the gradients
        grads[:, task].fill_(0.0)
        cnt = 0
        # for mm in m.shared_modules():
        #     for p in mm.parameters():

        for param in shared_params:
            grad = param.grad
            if grad is not None:
                grad_cur = grad.data.detach().clone()
                beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
                en = sum(grad_dims[: cnt + 1])
                grads[beg:en, task].copy_(grad_cur.data.view(-1))
            cnt += 1

    def overwrite_grad(self, shared_parameters, newgrad, grad_dims):
        """Unflatten ``newgrad`` and assign it to each shared parameter's .grad."""
        newgrad = newgrad * self.n_tasks  # to match the sum loss
        cnt = 0

        # for mm in m.shared_modules():
        #     for param in mm.parameters():
        for param in shared_parameters:
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            en = sum(grad_dims[: cnt + 1])
            this_grad = newgrad[beg:en].contiguous().view(param.data.size())
            param.grad = this_grad.data.clone()
            cnt += 1

    def backward(
        self,
        losses: torch.Tensor,
        parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        **kwargs,
    ):
        # GTG, w = self.get_weighted_loss(losses, shared_parameters)
        self.get_weighted_loss(losses, shared_parameters)
        if self.max_norm > 0:
            torch.nn.utils.clip_grad_norm_(shared_parameters, self.max_norm)
        # NOTE(review): siblings return (None, {}) — returning (None, None) here
        # is inconsistent; callers iterating extra_outputs may break. Confirm.
        return None, None  # NOTE: to align with all other weight methods


class WeightMethods:
    """Thin façade that instantiates a weight method by name from METHODS and
    forwards get_weighted_loss / backward / parameters to it."""

    def __init__(self, method: str, n_tasks: int, device: torch.device, **kwargs):
        """
        :param method: key into METHODS selecting the weighting algorithm
        """
        assert method in list(METHODS.keys()), f"unknown method {method}."

        self.method = METHODS[method](n_tasks=n_tasks, device=device, **kwargs)

    def get_weighted_loss(self, losses, **kwargs):
        return self.method.get_weighted_loss(losses, **kwargs)

    def backward(
        self, losses, **kwargs
    ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
        return self.method.backward(losses, **kwargs)

    def __ceil__(self, losses, **kwargs):
        # NOTE(review): `__ceil__` is almost certainly a typo for `__call__`
        # (math.ceil would never be applied to this object, and making the
        # façade callable is the obvious intent). Confirm call sites and rename.
        return self.backward(losses, **kwargs)

    def parameters(self):
        return self.method.parameters()


# Registry mapping CLI/config names to weight-method classes. Entries such as
# GO4ALIGN, GROUP, GROUP_RANDOM and the GROUP_* clustering variants are defined
# earlier in this file (above this chunk).
METHODS = dict(
    stl=STL,
    ls=LinearScalarization,
    uw=Uncertainty,
    pcgrad=PCGrad,
    mgda=MGDA,
    cagrad=CAGrad,
    nashmtl=NashMTL,
    scaleinvls=ScaleInvariantLinearScalarization,
    rlw=RLW,
    imtl=IMTLG,
    dwa=DynamicWeightAverage,
    # =========
    graddrop=GradDrop,
    famo=FAMO,
    go4align=GO4ALIGN,
    group=GROUP,
    group_random=GROUP_RANDOM,
    # =========
    group_sklearn_spectral_clustering_cluster_qr=GROUP_sklearn_spectral_clustering_cluster_qr,
    group_sklearn_spectral_clustering_discretize=GROUP_sklearn_spectral_clustering_discretize,
    group_sklearn_spectral_clustering_kmeans=GROUP_sklearn_spectral_clustering_kmeans,
    group_sdp_clustering=GROUP_sdp_clustering,

)