├── experiments
│   ├── __init__.py
│   ├── nyuv2
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── utils.py
│   │   ├── data.py
│   │   ├── trainer.py
│   │   └── models.py
│   └── utils.py
├── GO4Align.png
├── methods
│   ├── __init__.py
│   ├── sdp_kmeans
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── embedding.py
│   │   ├── nmf.py
│   │   └── sdp.py
│   ├── min_norm_solvers.py
│   ├── cluster_methods.py
│   └── weight_methods.py
├── requirements.txt
├── setup.py
└── README.md
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/nyuv2/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/GO4Align.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autumn9999/GO4Align/HEAD/GO4Align.png
--------------------------------------------------------------------------------
/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from methods.weight_methods import (
2 | METHODS,
3 | MGDA,
4 | STL,
5 | LinearScalarization,
6 | NashMTL,
7 | PCGrad,
8 | Uncertainty,
9 | )
10 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.2.1
2 | numpy>=1.18.2
3 | torch>=1.4.0
4 | torchvision>=0.8.0
5 | cvxpy
6 | tqdm>=4.45.0
7 | pandas
8 | scikit-learn
9 | seaborn
10 | wandb==0.12.21
11 | plotly
--------------------------------------------------------------------------------
/methods/sdp_kmeans/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, print_function
2 | from .embedding import sdp_kmeans_embedding, spectral_embedding
3 | from .nmf import symnmf_admm
4 | from .sdp import sdp_kmeans, sdp_km, sdp_km_burer_monteiro,\
5 | sdp_km_conditional_gradient
6 | from .utils import connected_components, dot_matrix
--------------------------------------------------------------------------------
/methods/sdp_kmeans/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.sparse.csgraph import connected_components as inner_conn_comp
4 |
5 | def dot_matrix(X):
  |     """Center X and return its Gram matrix, scaled by the largest row norm."""
6 |     X_norm = X - np.mean(X, axis=0)
7 |     X_norm /= np.max(np.linalg.norm(X_norm, axis=1))
8 |     return X_norm.dot(X_norm.T)
9 |
10 |
11 | def connected_components(sym_mat, thresh=1e-4):
12 | binary_mat = sym_mat > sym_mat.max() * thresh
13 | n_comp, labels = inner_conn_comp(binary_mat, directed=False,
14 | return_labels=True)
15 | clusters = [labels == i for i in range(n_comp)]
16 | return clusters
17 |
--------------------------------------------------------------------------------
/experiments/nyuv2/README.md:
--------------------------------------------------------------------------------
1 | # NYUv2 Experiment
2 |
3 | Modification of the code in [CAGrad](https://github.com/Cranial-XIX/CAGrad) and [MTAN](https://github.com/lorenmt/mtan).
4 |
5 | ## Dataset
6 |
7 | The dataset is available at [this link](https://www.dropbox.com/sh/86nssgwm6hm3vkb/AACrnUQ4GxpdrBbLjb6n-mWNa?dl=0). Put the downloaded files in `./dataset` so that the folder structure is `./dataset/train` and `./dataset/val`.
8 |
9 | ## Evaluation
10 |
11 | To align with previous work on MTL ([Liu et al., 2019](https://arxiv.org/abs/1803.10704); [Yu et al., 2020](https://arxiv.org/abs/2001.06782); [Liu et al., 2021](https://arxiv.org/pdf/2110.14048.pdf)), we report the test performance averaged
12 | over the last 10 epochs. Note that, apart from the `Last_10_TEST` summary printed at the end of `trainer.py`, this averaging needs to be applied by the user, as in the sketch below.
13 |
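   | A minimal sketch, assuming the per-epoch test metrics are collected row-wise in a NumPy array (as `avg_cost` in `trainer.py`) and saved to a hypothetical `avg_cost.npy`:
   | 
   | ```python
   | import numpy as np
   | 
   | # avg_cost: (n_epochs, n_metrics) array with one row of test metrics per epoch
   | avg_cost = np.load("avg_cost.npy")           # hypothetical dump of the logged metrics
   | final_metrics = avg_cost[-10:].mean(axis=0)  # average over the last 10 epochs
   | ```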
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from io import open
2 | from os import path
3 |
4 | from setuptools import find_packages, setup
5 |
6 | here = path.abspath(path.dirname(__file__))
7 |
8 | # get the long description from the README.md file
9 | with open(path.join(here, "README.md"), encoding="utf-8") as f:
10 | long_description = f.read()
11 |
12 |
13 | # get reqs
14 | def requirements():
15 | list_requirements = []
16 | with open("requirements.txt") as f:
17 | for line in f:
18 | list_requirements.append(line.rstrip())
19 | return list_requirements
20 |
21 |
22 | setup(
23 | name="nashmtl",
24 | version="1.0.0", # Required
25 | description="Nash-MTL", # Optional
26 |     long_description=long_description,  # Optional
27 |     long_description_content_type="text/markdown",  # Optional
28 | url="", # Optional
29 | author="", # Optional
30 | author_email="", # Optional
31 | packages=find_packages(exclude=["contrib", "docs", "tests"]),
33 | python_requires=">=3.6",
34 | install_requires=requirements(), # Optional
35 | )
36 |
--------------------------------------------------------------------------------
/methods/sdp_kmeans/embedding.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import eigh
3 | from .sdp import sdp_kmeans
4 |
5 |
6 | def sdp_kmeans_embedding(X, n_clusters, target_dim, ret_sdp=False,
7 | method='cvx'):
8 | D, Q = sdp_kmeans(X, n_clusters, method=method)
9 | Y = spectral_embedding(Q, target_dim=target_dim, discard_first=True)
10 | if ret_sdp:
11 | return Y, D, Q
12 | else:
13 | return Y
14 |
15 |
16 | def spectral_embedding(mat, target_dim, gramian=True, discard_first=True):
17 | if discard_first:
18 | last = -1
19 | first = target_dim - last
20 | else:
21 | first = target_dim
22 | last = None
23 | if not gramian:
24 | mat = mat - mat.mean(axis=0)
25 | mat = mat.dot(mat.T)
26 | eigvals, eigvecs = eigh(mat)
27 |
28 | sl = slice(-first, last)
29 | eigvecs = eigvecs[:, sl]
30 | eigvals_crop = eigvals[sl]
31 | Y = eigvecs.dot(np.diag(np.sqrt(eigvals_crop)))
32 | Y = Y[:, ::-1]
33 |
34 |     variance_explained(eigvals, eigvals_crop)
35 | return Y
36 |
37 |
38 | def variance_explained(eigvals, eigvals_crop):
39 | eigvals_crop[eigvals_crop < 0] = 0
40 | eigvals[eigvals < 0] = 0
41 | var = np.sum(eigvals_crop) / np.sum(eigvals)
42 | print('Variance explained:', var)
43 |
--------------------------------------------------------------------------------
/methods/sdp_kmeans/nmf.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division, absolute_import
2 | import numpy as np
3 |
4 |
5 | def symnmf_admm(A, k, H=None, maxiter=1e3, tol=1e-5, sigma=1):
6 |     """
7 |     Symmetric NMF via ADMM: given a symmetric matrix A (negative entries are
8 |     clipped to zero), solves min_W || A - W.dot(W.T) ||_F^2 s.t. W >= 0.
9 |     """
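  |     # ADMM splitting: keep two nonnegative copies W and H of the factor,
  |     # alternate regularized least-squares updates for each, and drive
  |     # W == H via the scaled dual variable Gamma.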
10 | A = A.copy()
11 | A[A < 0] = 0
12 |
13 | n = A.shape[0]
14 | if n != A.shape[1]:
15 | raise ValueError('A must be a symmetric matrix!')
16 |
17 | if H is None:
18 | H = np.sqrt(A.mean() / k) * np.random.randn(n, k)
19 | np.abs(H, H)
20 |
21 | Gamma = np.zeros((n, k))
22 | id_k = np.identity(k)
23 | step = 1
24 |
25 | error = []
26 | for i in range(int(maxiter)):
27 | temp = np.linalg.inv(H.T.dot(H) + sigma * id_k)
28 | W = (A.dot(H) + sigma * H - Gamma).dot(temp)
29 | W = np.maximum(W, 0)
30 | temp = np.linalg.inv(W.T.dot(W) + sigma * id_k)
31 | H = (A.dot(W) + sigma * W + Gamma).dot(temp)
32 | H = np.maximum(H, 0)
33 | Gamma += step * sigma * (W - H)
34 |
35 | error.append(np.linalg.norm(W - H) / np.linalg.norm(W))
36 | if i > 0 and np.abs(error[-1]) < tol:
37 | break
38 |
39 | return W
40 |
41 |
42 | def symnmf_gram_admm(A, k, H=None, maxiter=1e3, tol=1e-5, sigma=1):
43 | """
44 | Solves || A.dot(A.T) - W.dot(W.T) ||_F^2 s.t. W >= 0
45 | """
46 | A = A.copy()
47 | A[A < 0] = 0
48 |
49 |     n = A.shape[0]
50 |     if H is None:
51 | H = np.sqrt(A.mean() / k) * np.random.randn(n, k)
52 | np.abs(H, H)
53 |
54 | Gamma = np.zeros((n, k))
55 | id_k = np.identity(k)
56 | step = 1
57 |
58 | error = []
59 | for i in range(int(maxiter)):
60 | temp = np.linalg.inv(H.T.dot(H) + sigma * id_k)
61 | W = (A.dot(A.T.dot(H)) + sigma * H - Gamma).dot(temp)
62 | W = np.maximum(W, 0)
63 | temp = np.linalg.inv(W.T.dot(W) + sigma * id_k)
64 | H = (A.dot(A.T.dot(W)) + sigma * W + Gamma).dot(temp)
65 | H = np.maximum(H, 0)
66 | Gamma += step * sigma * (W - H)
67 |
68 | error.append(np.linalg.norm(W - H) / np.linalg.norm(W))
69 | if i > 0 and np.abs(error[-1]) < tol:
70 | break
71 |
72 | return W
73 |
--------------------------------------------------------------------------------
/experiments/nyuv2/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class ConfMatrix(object):
6 | def __init__(self, num_classes):
7 | self.num_classes = num_classes
8 | self.mat = None
9 |
10 | def update(self, pred, target):
11 | n = self.num_classes
12 | if self.mat is None:
13 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
14 | with torch.no_grad():
15 | k = (target >= 0) & (target < n)
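   |             # encode each (target, pred) pair as a single index n * t + p,
   |             # histogram with bincount, then reshape into an n x n confusion matrix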
16 | inds = n * target[k].to(torch.int64) + pred[k]
17 | self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
18 |
19 | def get_metrics(self):
20 | h = self.mat.float()
21 | acc = torch.diag(h).sum() / h.sum()
22 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
23 | return torch.mean(iu).cpu().numpy(), acc.cpu().numpy()
24 |
25 |
26 | def depth_error(x_pred, x_output):
27 | device = x_pred.device
28 | binary_mask = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).to(device)
29 | x_pred_true = x_pred.masked_select(binary_mask)
30 | x_output_true = x_output.masked_select(binary_mask)
31 | abs_err = torch.abs(x_pred_true - x_output_true)
32 | rel_err = torch.abs(x_pred_true - x_output_true) / x_output_true
33 | return (
34 | torch.sum(abs_err) / torch.nonzero(binary_mask, as_tuple=False).size(0)
35 | ).item(), (
36 | torch.sum(rel_err) / torch.nonzero(binary_mask, as_tuple=False).size(0)
37 | ).item()
38 |
39 |
40 | def normal_error(x_pred, x_output):
41 | binary_mask = torch.sum(x_output, dim=1) != 0
42 | error = (
43 | torch.acos(
44 | torch.clamp(
45 | torch.sum(x_pred * x_output, 1).masked_select(binary_mask), -1, 1
46 | )
47 | )
48 | .detach()
49 | .cpu()
50 | .numpy()
51 | )
52 | error = np.degrees(error)
53 | return (
54 | np.mean(error),
55 | np.median(error),
56 | np.mean(error < 11.25),
57 | np.mean(error < 22.5),
58 | np.mean(error < 30),
59 | )
60 |
61 |
62 | # for calculating \Delta_m
63 | delta_stats = [
64 | "mean iou",
65 | "pix acc",
66 | "abs err",
67 | "rel err",
68 | "mean",
69 | "median",
70 | "<11.25",
71 | "<22.5",
72 | "<30",
73 | ]
74 | BASE = np.array(
75 | [0.3830, 0.6376, 0.6754, 0.2780, 25.01, 19.21, 0.3014, 0.5720, 0.6915]
76 | ) # base results from CAGrad
77 | SIGN = np.array([1, 1, 0, 0, 0, 0, 1, 1, 1])
78 | KK = np.ones(9) * -1
79 |
80 |
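   | # \Delta_m is the average relative performance drop w.r.t. the single-task
   | # baseline BASE over the K = 9 metrics:
   | #   \Delta_m = 100 * (1/K) * sum_k (-1)^{SIGN_k} * (a_k - BASE_k) / BASE_k,
   | # where SIGN_k = 1 if a higher value of metric k is better and 0 otherwise.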
81 | def delta_fn(a):
82 | return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0 # * 100 for percentage
83 |
--------------------------------------------------------------------------------
/experiments/nyuv2/data.py:
--------------------------------------------------------------------------------
1 | import fnmatch
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.utils.data.dataset import Dataset
9 |
10 | """Source: https://github.com/Cranial-XIX/CAGrad/blob/main/nyuv2/create_dataset.py
11 |
12 | """
13 |
14 |
15 | class RandomScaleCrop(object):
16 | """
17 | Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
18 | """
19 |
20 | def __init__(self, scale=[1.0, 1.2, 1.5]):
21 | self.scale = scale
22 |
23 | def __call__(self, img, label, depth, normal):
24 | height, width = img.shape[-2:]
25 | sc = self.scale[random.randint(0, len(self.scale) - 1)]
26 | h, w = int(height / sc), int(width / sc)
27 | i = random.randint(0, height - h)
28 | j = random.randint(0, width - w)
29 | img_ = F.interpolate(
30 | img[None, :, i : i + h, j : j + w],
31 | size=(height, width),
32 | mode="bilinear",
33 | align_corners=True,
34 | ).squeeze(0)
35 | label_ = (
36 | F.interpolate(
37 | label[None, None, i : i + h, j : j + w],
38 | size=(height, width),
39 | mode="nearest",
40 | )
41 | .squeeze(0)
42 | .squeeze(0)
43 | )
44 | depth_ = F.interpolate(
45 | depth[None, :, i : i + h, j : j + w], size=(height, width), mode="nearest"
46 | ).squeeze(0)
47 | normal_ = F.interpolate(
48 | normal[None, :, i : i + h, j : j + w],
49 | size=(height, width),
50 | mode="bilinear",
51 | align_corners=True,
52 | ).squeeze(0)
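   |         # cropping a 1/sc-sized window and resizing back to full resolution
   |         # zooms the scene in, so depth is divided by sc to stay consistent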
53 | return img_, label_, depth_ / sc, normal_
54 |
55 |
56 | class NYUv2(Dataset):
57 | """
58 | We could further improve the performance with the data augmentation of NYUv2 defined in:
59 | [1] PAD-Net: Multi-Tasks Guided Prediction-and-Distillation Network for Simultaneous Depth Estimation and Scene Parsing
60 | [2] Pattern affinitive propagation across depth, surface normal and semantic segmentation
61 | [3] Mti-net: Multiscale task interaction networks for multi-task learning
62 |     1. Random scale with a ratio selected from 1.0, 1.2, and 1.5.
63 |     2. Random horizontal flip.
64 |     Note that all baselines and MTAN did NOT apply data augmentation in the original paper.
65 | """
66 |
67 | def __init__(self, root, train=True, augmentation=False):
68 | self.train = train
69 | self.root = os.path.expanduser(root)
70 | self.augmentation = augmentation
71 |
72 | # read the data file
73 |         if train:
74 |             self.data_path = self.root + "/train"
75 |         else:
76 |             self.data_path = self.root + "/val"
77 |
78 | # calculate data length
79 | self.data_len = len(
80 | fnmatch.filter(os.listdir(self.data_path + "/image"), "*.npy")
81 | )
82 |
83 | def __getitem__(self, index):
84 | # load data from the pre-processed npy files
85 | image = torch.from_numpy(
86 | np.moveaxis(
87 | np.load(self.data_path + "/image/{:d}.npy".format(index)), -1, 0
88 | )
89 | )
90 | semantic = torch.from_numpy(
91 | np.load(self.data_path + "/label/{:d}.npy".format(index))
92 | )
93 | depth = torch.from_numpy(
94 | np.moveaxis(
95 | np.load(self.data_path + "/depth/{:d}.npy".format(index)), -1, 0
96 | )
97 | )
98 | normal = torch.from_numpy(
99 | np.moveaxis(
100 | np.load(self.data_path + "/normal/{:d}.npy".format(index)), -1, 0
101 | )
102 | )
103 |
104 | # apply data augmentation if required
105 | if self.augmentation:
106 | image, semantic, depth, normal = RandomScaleCrop()(
107 | image, semantic, depth, normal
108 | )
109 | if torch.rand(1) < 0.5:
110 | image = torch.flip(image, dims=[2])
111 | semantic = torch.flip(semantic, dims=[1])
112 | depth = torch.flip(depth, dims=[2])
113 | normal = torch.flip(normal, dims=[2])
114 | normal[0, :, :] = -normal[0, :, :]
115 |
116 | return image.float(), semantic.float(), depth.float(), normal.float()
117 |
118 | def __len__(self):
119 | return self.data_len
120 |
--------------------------------------------------------------------------------
/experiments/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import random
4 | from collections import defaultdict
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from methods import METHODS
11 |
12 |
13 | def list_of_float(arg):
14 | return list(map(float, arg.split(',')))
15 |
16 | def str_to_list(string):
17 | return [float(s) for s in string.split(",")]
18 |
19 |
20 | def str_or_float(value):
21 | try:
22 | return float(value)
23 |     except (ValueError, TypeError):
24 | return value
25 |
26 |
27 | def str2bool(v):
28 | if isinstance(v, bool):
29 | return v
30 | if v.lower() in ("yes", "true", "t", "y", "1"):
31 | return True
32 | elif v.lower() in ("no", "false", "f", "n", "0"):
33 | return False
34 | else:
35 | raise argparse.ArgumentTypeError("Boolean value expected.")
36 |
37 |
38 | common_parser = argparse.ArgumentParser(add_help=False)
39 | common_parser.add_argument("--data-path", type=Path, help="path to data")
40 | common_parser.add_argument("--n-epochs", type=int, default=300)
41 | common_parser.add_argument("--batch-size", type=int, default=120, help="batch size")
42 | common_parser.add_argument("--method", type=str, choices=list(METHODS.keys()), help="MTL weight method")
43 | common_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
44 | common_parser.add_argument(
45 | "--method-params-lr",
46 | type=float,
47 | default=0.025,
48 | help="lr for weight method params. If None, set to args.lr. For uncertainty weighting",
49 | )
50 | common_parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
51 | common_parser.add_argument("--seed", type=int, default=42, help="seed value")
52 | # NashMTL
53 | common_parser.add_argument(
54 | "--nashmtl-optim-niter", type=int, default=20, help="number of CCCP iterations"
55 | )
56 | common_parser.add_argument(
57 | "--update-weights-every",
58 | type=int,
59 | default=1,
60 | help="update task weights every x iterations.",
61 | )
62 | # stl
63 | common_parser.add_argument(
64 | "--main-task",
65 | type=int,
66 | default=0,
67 | help="main task for stl. Ignored if method != stl",
68 | )
69 | # cagrad
70 | common_parser.add_argument("--c", type=float, default=0.4, help="c for CAGrad alg.")
71 | # dwa
73 | common_parser.add_argument(
74 | "--dwa-temp",
75 | type=float,
76 | default=2.0,
77 | help="Temperature hyper-parameter for DWA. Default to 2 like in the original paper.",
78 | )
79 |
80 |
81 | def count_parameters(model):
82 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
83 |
84 |
85 | def set_logger():
86 | logging.basicConfig(
87 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
88 | level=logging.INFO,
89 | )
90 |
91 |
92 | def set_seed(seed):
93 | """for reproducibility
94 | :param seed:
95 | :return:
96 | """
97 | np.random.seed(seed)
98 | random.seed(seed)
99 |
100 | torch.manual_seed(seed)
101 | if torch.cuda.is_available():
102 | torch.cuda.manual_seed(seed)
103 | torch.cuda.manual_seed_all(seed)
104 |
105 | torch.backends.cudnn.enabled = True
106 | torch.backends.cudnn.benchmark = False
107 | torch.backends.cudnn.deterministic = True
108 |
109 |
114 |
115 | #-------------------
116 | def get_device(no_cuda=False, gpus=0):
117 | return torch.device(
118 | f"cuda:{gpus}" if torch.cuda.is_available() and not no_cuda else "cpu"
119 | )
120 | #-------------------
121 |
122 | def extract_weight_method_parameters_from_args(args):
123 | weight_methods_parameters = defaultdict(dict)
124 | weight_methods_parameters.update(
125 | dict(
126 |             nashmtl=dict(update_weights_every=args.update_weights_every, optim_niter=args.nashmtl_optim_niter),
127 | stl=dict(main_task=args.main_task),
128 | cagrad=dict(c=args.c),
129 | dwa=dict(temp=args.dwa_temp),
130 | #-------------
131 | go4align=dict(num_groups=args.num_groups, robust_step_size=args.robust_step_size),
132 | group=dict(task_weights=args.task_weights, robust_step_size=args.robust_step_size),
133 | group_random=dict(task_weights=args.task_weights, robust_step_size=args.robust_step_size),
134 | #-------------rebuttal
135 | group_sklearn_spectral_clustering_cluster_qr=dict(num_groups=args.num_groups, robust_step_size=args.robust_step_size),
136 | group_sklearn_spectral_clustering_discretize=dict(num_groups=args.num_groups, robust_step_size=args.robust_step_size),
137 | group_sklearn_spectral_clustering_kmeans=dict(num_groups=args.num_groups, robust_step_size=args.robust_step_size),
138 | group_sdp_clustering=dict(num_groups=args.num_groups, robust_step_size=args.robust_step_size),
139 | )
140 | )
141 | return weight_methods_parameters
142 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GO4Align
2 |
3 | Welcome to the official repository for "GO4Align: Group Optimization for Multi-Task Alignment," an effective and efficient approach to multi-task optimization.
4 |
5 |
6 | ### Project Webpage
7 |
8 | Details will be available soon.
9 |
10 | ### Abstract
11 |
12 | This paper proposes GO4Align, a multi-task optimization approach that tackles task imbalance by explicitly aligning the optimization across tasks.
13 | To achieve this, we design an adaptive group risk minimization strategy, comprising two crucial techniques in implementation:
14 | - **dynamical group assignment**, which clusters similar tasks based on task interactions;
15 | - **risk-guided group indicators**, which exploit consistent task correlations with risk information from previous iterations.
16 |
17 | Comprehensive experimental results on diverse benchmarks demonstrate our method's superior performance with even lower computational costs.
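   | 
   | A minimal illustrative sketch of these two techniques (ours, not the paper's exact algorithm; see `methods/weight_methods.py` for the actual implementation; `num_groups` and `robust_step_size` mirror this repo's CLI flags):
   | 
   | ```python
   | import numpy as np
   | from sklearn.cluster import KMeans
   | 
   | def group_risk_weights(loss_history, num_groups=2, robust_step_size=1e-4):
   |     """Cluster tasks by their loss trajectories, then up-weight riskier groups."""
   |     # loss_history: (n_steps, n_tasks) array of per-task training losses
   |     ratios = loss_history / loss_history[0]          # scale-free risk per task
   |     labels = KMeans(n_clusters=num_groups, n_init=10).fit_predict(ratios.T)
   |     weights = np.ones(loss_history.shape[1])
   |     for g in range(num_groups):                      # risk-guided group indicators
   |         group_risk = ratios[-1, labels == g].mean()
   |         weights[labels == g] *= np.exp(robust_step_size * group_risk)
   |     return weights / weights.sum()
   | ```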
18 |
19 | ### Paper
20 |
21 | [The preprint of our paper](https://arxiv.org/abs/2404.06486) is available on arXiv.
22 |
23 | ### Framework of Adaptive Group Risk Minimization
24 |
25 | ![Framework of adaptive group risk minimization](GO4Align.png)
30 |
31 | ---
32 |
33 | ### Setup Environment
34 |
35 | We recommend using miniconda to create a virtual environment for running the code:
36 | ```bash
37 | conda create -n go4align python=3.9.7
38 | conda activate go4align
39 | conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.7 -c pytorch -c nvidia
40 | conda install pyg -c pyg -c conda-forge
41 | ```
42 |
43 | Install the package by running the following commands in the terminal:
44 | ```bash
45 | git clone https://github.com/autumn9999/GO4Align.git
46 | cd GO4Align
47 | pip install -e .
48 | ```
49 |
50 | Experiments were run on an NVIDIA A100-SXM4-40GB GPU.
51 |
52 | ### Download Datasets
53 |
54 | This work is evaluated on several multi-task learning benchmarks:
55 |
56 | 1. [NYUv2](https://www.dropbox.com/sh/86nssgwm6hm3vkb/AACrnUQ4GxpdrBbLjb6n-mWNa?dl=0) (3 tasks), where the link is provided by the previous MTO work [CAGrad](https://github.com/Cranial-XIX/CAGrad.git).
57 | 2. [CityScapes](https://www.dropbox.com/sh/gaw6vh6qusoyms6/AADwWi0Tp3E3M4B2xzeGlsEna?dl=0) (2 tasks), where the link is provided by [CAGrad](https://github.com/Cranial-XIX/CAGrad.git).
58 | 3. [CelebA](https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg) (40 tasks). Details can be found in the previous MTO work [FAMO](https://github.com/Cranial-XIX/FAMO).
59 | 4. QM9 (11 tasks), which can be downloaded automatically by `torch_geometric.datasets`. Details can be found in [FAMO](https://github.com/Cranial-XIX/FAMO/blob/main/experiments/quantum_chemistry/trainer.py).
60 |
61 |
62 |
63 | ### Run Experiments
64 | Here we provide the experiment code for NYUv2. To run experiments on other benchmarks, please refer to the unified APIs in [NashMTL](https://github.com/AvivNavon/nash-mtl) or [FAMO](https://github.com/Cranial-XIX/FAMO).
65 |
66 | ```bash
67 | cd experiments/nyuv2
68 | python trainer.py --method go4align
69 | ```
70 |
71 | We also support the following MTL methods as alternatives.
72 |
73 | | Method (code name) | Paper (notes) |
74 | |:-------------------------:|:-------------------------------------------------------------------------------------------------------------------------------:|
75 | | **Gradient-oriented methods** | |
76 | | MGDA | [Multi-Task Learning as Multi-Objective Optimization](https://arxiv.org/pdf/1810.04650) |
77 | | PCGrad | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) |
78 | | CAGrad | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048.pdf) |
79 | | IMTL-G | [Towards Impartial Multi-task Learning](https://openreview.net/forum?id=IMPnRXEWpvr) |
80 | | NashMTL | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017v1.pdf) |
81 | | **Loss-oriented methods** | |
82 | | LS | - (equal weighting) |
83 | | SI | - (see Nash-MTL paper for details) |
84 | | RLW | [A Closer Look at Loss Weighting in Multi-Task Learning](https://arxiv.org/pdf/2111.10603.pdf) |
85 | | DWA | [End-to-End Multi-Task Learning with Attention](https://arxiv.org/pdf/1803.10704) |
86 | | UW | [Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics](https://arxiv.org/pdf/1705.07115v3.pdf) |
87 | | FAMO | [FAMO: Fast Adaptive Multitask Optimization](https://arxiv.org/pdf/2306.03792) |
88 |
89 |
90 |
91 | ### Citation
92 | This repo is built upon [NashMTL](https://github.com/AvivNavon/nash-mtl) and [FAMO](https://github.com/Cranial-XIX/FAMO). If our work **GO4Align** is helpful in your research or projects, please cite the following papers:
93 |
94 | ```bib
95 | @article{shen2024go4align,
96 | title={GO4Align: Group Optimization for Multi-Task Alignment},
97 | author={Shen, Jiayi and Wang, Cheems and Xiao, Zehao and Van Noord, Nanne and Worring, Marcel},
98 | journal={arXiv preprint arXiv:2404.06486},
99 | year={2024}
100 | }
101 |
102 | @article{liu2024famo,
103 | title={Famo: Fast adaptive multitask optimization},
104 | author={Liu, Bo and Feng, Yihao and Stone, Peter and Liu, Qiang},
105 | journal={Advances in Neural Information Processing Systems},
106 | volume={36},
107 | year={2024}
108 | }
109 |
110 | @article{navon2022multi,
111 | title={Multi-task learning as a bargaining game},
112 | author={Navon, Aviv and Shamsian, Aviv and Achituve, Idan and Maron, Haggai and Kawaguchi, Kenji and Chechik, Gal and Fetaya, Ethan},
113 | journal={arXiv preprint arXiv:2202.01017},
114 | year={2022}
115 | }
116 | ```
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/methods/min_norm_solvers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | # This code is from
6 | # Multi-Task Learning as Multi-Objective Optimization
7 | # Ozan Sener, Vladlen Koltun
8 | # Neural Information Processing Systems (NeurIPS) 2018
9 | # https://github.com/intel-isl/MultiObjectiveOptimization
10 | class MinNormSolver:
11 | MAX_ITER = 250
12 | STOP_CRIT = 1e-5
13 |
14 | @staticmethod
15 | def _min_norm_element_from2(v1v1, v1v2, v2v2):
16 |         """
17 |         Analytical solution for min_{c} |c x_1 + (1-c) x_2|_2^2
18 |         d is the distance (objective) optimized
19 |         v1v1 = <x1, x1>
20 |         v1v2 = <x1, x2>
21 |         v2v2 = <x2, x2>
22 |         """
23 | if v1v2 >= v1v1:
24 | # Case: Fig 1, third column
25 | gamma = 0.999
26 | cost = v1v1
27 | return gamma, cost
28 | if v1v2 >= v2v2:
29 | # Case: Fig 1, first column
30 | gamma = 0.001
31 | cost = v2v2
32 | return gamma, cost
33 | # Case: Fig 1, second column
34 | gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
35 | cost = v2v2 + gamma * (v1v2 - v2v2)
36 | return gamma, cost
37 |
38 | @staticmethod
39 | def _min_norm_2d(vecs, dps):
40 | """
41 | Find the minimum norm solution as combination of two points
42 | This is correct only in 2D
43 | ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j
44 | """
45 | dmin = 1e8
46 | for i in range(len(vecs)):
47 | for j in range(i + 1, len(vecs)):
48 | if (i, j) not in dps:
49 | dps[(i, j)] = 0.0
50 | for k in range(len(vecs[i])):
51 | dps[(i, j)] += torch.dot(
52 | vecs[i][k], vecs[j][k]
53 | ).item() # torch.dot(vecs[i][k], vecs[j][k]).dataset[0]
54 | dps[(j, i)] = dps[(i, j)]
55 | if (i, i) not in dps:
56 | dps[(i, i)] = 0.0
57 | for k in range(len(vecs[i])):
58 | dps[(i, i)] += torch.dot(
59 | vecs[i][k], vecs[i][k]
60 | ).item() # torch.dot(vecs[i][k], vecs[i][k]).dataset[0]
61 | if (j, j) not in dps:
62 | dps[(j, j)] = 0.0
63 | for k in range(len(vecs[i])):
64 | dps[(j, j)] += torch.dot(
65 | vecs[j][k], vecs[j][k]
66 | ).item() # torch.dot(vecs[j][k], vecs[j][k]).dataset[0]
67 | c, d = MinNormSolver._min_norm_element_from2(
68 | dps[(i, i)], dps[(i, j)], dps[(j, j)]
69 | )
70 | if d < dmin:
71 | dmin = d
72 | sol = [(i, j), c, d]
73 | return sol, dps
74 |
75 | @staticmethod
76 | def _projection2simplex(y):
77 | """
78 | Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
79 | """
80 | m = len(y)
81 | sorted_y = np.flip(np.sort(y), axis=0)
82 | tmpsum = 0.0
83 | tmax_f = (np.sum(y) - 1.0) / m
84 | for i in range(m - 1):
85 | tmpsum += sorted_y[i]
86 | tmax = (tmpsum - 1) / (i + 1.0)
87 | if tmax > sorted_y[i + 1]:
88 | tmax_f = tmax
89 | break
90 | return np.maximum(y - tmax_f, np.zeros(y.shape))
91 |
92 | @staticmethod
93 | def _next_point(cur_val, grad, n):
94 | proj_grad = grad - (np.sum(grad) / n)
95 | tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
96 | tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])
97 |
98 | skippers = np.sum(tm1 < 1e-7) + np.sum(tm2 < 1e-7)
99 | t = 1
100 | if len(tm1[tm1 > 1e-7]) > 0:
101 | t = np.min(tm1[tm1 > 1e-7])
102 | if len(tm2[tm2 > 1e-7]) > 0:
103 | t = min(t, np.min(tm2[tm2 > 1e-7]))
104 |
105 | next_point = proj_grad * t + cur_val
106 | next_point = MinNormSolver._projection2simplex(next_point)
107 | return next_point
108 |
109 | @staticmethod
110 | def find_min_norm_element(vecs):
111 | """
112 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
113 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
114 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
115 | Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence
116 | """
117 | # Solution lying at the combination of two points
118 | dps = {}
119 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
120 |
121 | n = len(vecs)
122 | sol_vec = np.zeros(n)
123 | sol_vec[init_sol[0][0]] = init_sol[1]
124 | sol_vec[init_sol[0][1]] = 1 - init_sol[1]
125 |
126 | if n < 3:
127 | # This is optimal for n=2, so return the solution
128 | return sol_vec, init_sol[2]
129 |
130 | iter_count = 0
131 |
132 | grad_mat = np.zeros((n, n))
133 | for i in range(n):
134 | for j in range(n):
135 | grad_mat[i, j] = dps[(i, j)]
136 |
137 | while iter_count < MinNormSolver.MAX_ITER:
138 | grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
139 | new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
140 | # Re-compute the inner products for line search
141 | v1v1 = 0.0
142 | v1v2 = 0.0
143 | v2v2 = 0.0
144 | for i in range(n):
145 | for j in range(n):
146 | v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
147 | v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
148 | v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
149 | nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
150 | new_sol_vec = nc * sol_vec + (1 - nc) * new_point
151 | change = new_sol_vec - sol_vec
152 | if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
153 | return sol_vec, nd
154 |             sol_vec = new_sol_vec
    |             iter_count += 1
    |         # MAX_ITER reached without convergence: return the last iterate
    |         return sol_vec, nd
155 |
156 | @staticmethod
157 | def find_min_norm_element_FW(vecs):
158 | """
159 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
160 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
161 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
162 | Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence
163 | """
164 | # Solution lying at the combination of two points
165 | dps = {}
166 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
167 |
168 | n = len(vecs)
169 | sol_vec = np.zeros(n)
170 | sol_vec[init_sol[0][0]] = init_sol[1]
171 | sol_vec[init_sol[0][1]] = 1 - init_sol[1]
172 |
173 | if n < 3:
174 | # This is optimal for n=2, so return the solution
175 | return sol_vec, init_sol[2]
176 |
177 | iter_count = 0
178 |
179 | grad_mat = np.zeros((n, n))
180 | for i in range(n):
181 | for j in range(n):
182 | grad_mat[i, j] = dps[(i, j)]
183 |
184 | while iter_count < MinNormSolver.MAX_ITER:
185 | t_iter = np.argmin(np.dot(grad_mat, sol_vec))
186 |
187 | v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
188 | v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
189 | v2v2 = grad_mat[t_iter, t_iter]
190 |
191 | nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
192 | new_sol_vec = nc * sol_vec
193 | new_sol_vec[t_iter] += 1 - nc
194 |
195 | change = new_sol_vec - sol_vec
196 | if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
197 | return sol_vec, nd
198 |             sol_vec = new_sol_vec
    |             iter_count += 1
    |         # MAX_ITER reached without convergence: return the last iterate
    |         return sol_vec, nd
199 |
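    | # Illustrative usage (not part of the original code): each element of `vecs`
    | # is a list of per-parameter gradient tensors for one task, e.g.
    | #   g1 = [torch.tensor([1.0, 0.0])]
    | #   g2 = [torch.tensor([0.0, 1.0])]
    | #   sol, cost = MinNormSolver.find_min_norm_element([g1, g2])
    | # gives sol ~ [0.5, 0.5], the min-norm point on the segment between g1 and g2.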
200 |
201 | def gradient_normalizers(grads, losses, normalization_type):
202 | gn = {}
203 | if normalization_type == "norm":
204 | for t in grads:
205 |             gn[t] = np.sqrt(np.sum([gr.pow(2).sum().item() for gr in grads[t]]))
206 | elif normalization_type == "loss":
207 | for t in grads:
208 | gn[t] = losses[t]
209 | elif normalization_type == "loss+":
210 | for t in grads:
211 |             gn[t] = losses[t] * np.sqrt(
212 |                 np.sum([gr.pow(2).sum().item() for gr in grads[t]])
213 |             )
214 | elif normalization_type == "none":
215 | for t in grads:
216 | gn[t] = 1.0
217 |     else:
218 |         raise ValueError("Invalid normalization type: {}".format(normalization_type))
219 | return gn
220 |
--------------------------------------------------------------------------------
/methods/sdp_kmeans/sdp.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division, absolute_import
2 | import cvxpy as cp
3 | from functools import partial
4 | import numpy as np
5 | import scipy.sparse as sp
  | import scipy.sparse.linalg  # makes sp.linalg.eigsh available
6 | from scipy.optimize import minimize
7 | from .nmf import symnmf_gram_admm, symnmf_admm
8 | from .utils import dot_matrix
9 |
10 |
11 | def sdp_kmeans(X, n_clusters, method='cvx'):
12 | if method == 'cvx':
13 | D = dot_matrix(X)
14 | Q = sdp_km(D, n_clusters)
15 | elif method == 'cgm':
16 | D = dot_matrix(X)
17 | Q = sdp_km_conditional_gradient(D, n_clusters)
18 | elif method == 'bm':
19 | Y = sdp_km_burer_monteiro(X, n_clusters)
20 | D = dot_matrix(X)
21 | Q = Y.dot(Y.T)
22 | else:
23 |         raise ValueError('The method should be one of "cvx", "cgm" and "bm"')
24 |
25 | return D, Q
26 |
27 |
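   | # sdp_km solves the semidefinite relaxation of k-means:
   | #   maximize    tr(D Z)
   | #   subject to  Z PSD,  Z >= 0 (elementwise),  Z 1 = 1,  tr(Z) = n_clusters.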
28 | def sdp_km(D, n_clusters, max_iters=5000, eps=1e-5):
29 | ones = np.ones((D.shape[0], 1))
30 | try:
31 | Z = cp.Variable(D.shape, PSD=True)
32 | except TypeError:
33 | Z = cp.Semidef(D.shape[0])
34 |     objective = cp.Maximize(cp.trace(D @ Z))
35 |     constraints = [Z >= 0,
36 |                    Z @ ones == ones,
37 |                    cp.trace(Z) == n_clusters]
38 |
39 | prob = cp.Problem(objective, constraints)
40 | prob.solve(solver=cp.SCS, verbose=False, max_iters=max_iters, eps=eps)
41 |
42 | Q = np.asarray(Z.value)
43 | rs = Q.sum(axis=1)
44 | print('Q', Q.min(), Q.max(), '|',
45 | rs.min(), rs.max(), '|',
46 | np.trace(Q), np.trace(D.dot(Q)))
47 | print('Final objective', np.trace(D.dot(Q)))
48 |
49 | return Q
50 |
51 |
52 | def sdp_km_burer_monteiro(X, n_clusters, rank=None, maxiter=1e3, tol=1e-5):
53 | if rank is None:
54 | rank = 8 * n_clusters
55 |
56 | X_norm = X - np.mean(X, axis=0)
57 | if X_norm.shape[0] > X_norm.shape[1]:
58 | cov = X_norm.T.dot(X_norm)
59 | else:
60 | XXt = X_norm.dot(X_norm.T)
61 | cov = XXt
62 | X_norm /= np.trace(cov.dot(cov)) ** 0.25
63 |
64 | if X_norm.shape[0] <= X_norm.shape[1]:
65 | XXt /= np.trace(cov.dot(cov)) ** 0.5
66 |
67 | Y_shape = (len(X), rank)
68 | ones = np.ones((len(X), 1))
69 |
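   |     # Burer-Monteiro idea: replace the PSD matrix Z of sdp_km by a low-rank
   |     # factorization Z = Y Y^T with 0 <= Y <= 1, and minimize the augmented
   |     # Lagrangian of the constraints tr(Y^T Y) = n_clusters and Y Y^T 1 = 1
   |     # with L-BFGS-B, updating the multipliers after each inner solve.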
70 | def lagrangian(x, lambda1, lambda2, sigma1, sigma2):
71 | Y = x.reshape(Y_shape)
72 |
73 | if X_norm.shape[0] > X_norm.shape[1]:
74 | YtX = Y.T.dot(X_norm)
75 | obj = -np.trace(YtX.dot(YtX.T))
76 | else:
77 | obj = -np.trace(Y.T.dot(XXt).dot(Y))
78 |
79 | trYtY_minus_nclusters = np.trace(Y.T.dot(Y)) - n_clusters
80 | obj -= lambda1 * trYtY_minus_nclusters
81 | obj += .5 * sigma1 * trYtY_minus_nclusters ** 2
82 |
83 | YYt1_minus_1 = Y.dot(Y.T.dot(ones)) - ones
84 | obj -= lambda2.T.dot(YYt1_minus_1)[0, 0]
85 | obj += .5 * sigma2 * np.sum(YYt1_minus_1 ** 2)
86 |
87 | return obj
88 |
89 | def grad(x, lambda1, lambda2, sigma1, sigma2):
90 | Y = x.reshape(Y_shape)
91 |
92 | if X_norm.shape[0] > X_norm.shape[1]:
93 | delta = -2 * X_norm.dot(X_norm.T.dot(Y))
94 | else:
95 | delta = -2 * XXt.dot(Y)
96 |
97 | YtY = Y.T.dot(Y)
98 | delta -= 2 * (lambda1
99 | -sigma1 * (np.trace(Y.T.dot(Y)) - n_clusters)) * Y
100 |
101 | delta -= ones.dot(lambda2.T.dot(Y)) + lambda2.dot(ones.T.dot(Y))
102 |
103 | Yt1 = Y.T.dot(ones)
104 | delta += sigma2 * (Y.dot(Yt1).dot(Yt1.T) + ones.dot(Yt1.T).dot(YtY)
105 | -2 * ones.dot(Yt1.T))
106 |
107 | return delta.flatten()
108 |
109 | if X_norm.shape[0] > X_norm.shape[1]:
110 | Y = symnmf_gram_admm(X_norm, rank)
111 | else:
112 | Y = symnmf_admm(XXt, rank)
113 |
114 | lambda1 = 0.
115 | lambda2 = np.zeros((len(X), 1))
116 | sigma1 = 1
117 | sigma2 = 1
118 | step = 1
119 |
120 | error = []
121 | for n_iter in range(int(maxiter)):
122 | fun = partial(lagrangian, lambda1=lambda1, lambda2=lambda2,
123 | sigma1=sigma1, sigma2=sigma2)
124 | jac = partial(grad, lambda1=lambda1, lambda2=lambda2,
125 | sigma1=sigma1, sigma2=sigma2)
126 | bounds = [(0, 1)] * np.prod(Y_shape)
127 |
128 | Y_old = Y.copy()
129 | res = minimize(fun, Y.flatten(), jac=jac, bounds=bounds,
130 | method='L-BFGS-B',)
131 | Y = res.x.reshape(Y_shape)
132 |
133 | lambda1 -= step * sigma1 * (np.trace(Y.T.dot(Y)) - n_clusters)
134 | lambda2 -= step * sigma2 * (Y.dot(Y.T.dot(ones)) - ones)
135 |
136 | error.append(np.linalg.norm(Y - Y_old) / np.linalg.norm(Y_old))
137 |
138 | if error[-1] < tol:
139 | break
140 |
141 | return Y
142 |
143 |
144 | def sdp_km_conditional_gradient(D, n_clusters, max_iter=2e3,
145 | stop_tol_max=1e-3, stop_tol_rmse=1e-4,
146 | n_inner_iter=15,
147 | use_line_search=False,
148 | verbose=False, track_stats=False):
149 | n = len(D)
150 | one_over_n = 1. / n
151 |
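   |     # Conditional-gradient (Frank-Wolfe) scheme on the shifted variable
   |     # Q = Z - 1/n: each linear minimization reduces to one extreme
   |     # eigenvector (solve_lp), and nonnegativity of Z = Q + 1/n is enforced
   |     # via an augmented-Lagrangian penalty (lagrange_lower_bound).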
152 | def lagrangian(Q, lagrange_lower_bound, t):
153 | obj = -np.sum(D * Q)
154 | obj += np.sum(lagrange_lower_bound * (Q + one_over_n))
155 | penalty = (t + 1) ** 0.5
156 | obj += penalty * np.sum(np.minimum(Q + one_over_n, 0) ** 2)
157 | return obj
158 |
159 | def gradient(Q, lagrange_lower_bound, t):
160 | delta = -D
161 |
162 | delta += lagrange_lower_bound
163 | penalty = (t + 1) ** 0.5
164 | delta += penalty * np.minimum(Q + one_over_n, 0)
165 | return delta
166 |
167 | def solve_lp(grad, t):
168 | # The following commented 2 lines are the mathematically
169 | # friendly version of the 2 lines following them.
170 | # ortho_mat = np.eye(n) - np.ones_like(D) / n
171 | # A = ortho_mat.dot(grad).dot(ortho_mat)
172 | grad11 = np.broadcast_to(grad.sum(axis=1) / n, (n, n))
173 | A = grad - grad11 - grad11.T + grad.sum() / (n ** 2)
174 |
175 | tol = (t + 1) ** -1
176 | return sp.linalg.eigsh(A, k=1, which='SA', tol=tol)
177 |
178 | def line_search(update, current, lagrange_lower_bound, t, tol=1e-5):
179 | gr = (np.sqrt(5) + 1) / 2
180 | a = 0
181 | b = 1
182 |
183 | c = b - (b - a) / gr
184 | d = a + (b - a) / gr
185 | while abs(c - d) > tol:
186 | convex_comb_c = c * update + (1 - c) * current
187 | convex_comb_d = d * update + (1 - d) * current
188 | fc = lagrangian(convex_comb_c, lagrange_lower_bound, t)
189 | fd = lagrangian(convex_comb_d, lagrange_lower_bound, t)
190 | if fc < fd:
191 | b = d
192 | else:
193 | a = c
194 |
195 | # we recompute both c and d here to avoid loss of precision
196 | # which may lead to incorrect results or infinite loop
197 | c = b - (b - a) / gr
198 | d = a + (b - a) / gr
199 |
200 | return (b + a) / 2
201 |
202 |
203 | if track_stats or verbose:
204 | rmse_list = []
205 | obj_value_list = []
206 |
207 | Q = np.zeros_like(D)
208 | lagrange_lower_bound = np.zeros(D.shape)
209 | step = 1
210 |
211 | for t in range(int(max_iter)):
212 | for inner_it in range(n_inner_iter):
213 | grad = gradient(Q, lagrange_lower_bound, t)
214 |             s, v = solve_lp(grad, inner_it)
215 |
216 | if s < 0:
217 | update = (n_clusters - 1) * np.outer(v, v)
218 | if use_line_search:
219 | eta = line_search(update, Q, lagrange_lower_bound,
220 | t * n_inner_iter + inner_it)
221 | else:
222 | eta = 2. / (t * n_inner_iter + inner_it + 2)
223 | Q = (1 - eta) * Q + eta * update
224 |
225 | Q_nneg = Q + one_over_n
226 |
227 | rmse = np.sqrt(np.mean(Q_nneg[Q_nneg < 0] ** 2))
228 | max_error = np.abs(np.min(Q_nneg[Q_nneg < 0])) / (n_clusters / n)
229 |
230 | if track_stats or verbose:
231 | obj_value_list.append(np.trace(D.dot(Q)))
232 | rmse_list.append(rmse)
233 |
234 | lagrange_lower_bound += step * Q_nneg
235 | np.minimum(lagrange_lower_bound, 0, out=lagrange_lower_bound)
236 |
237 | if verbose and t % 10 == 0:
238 | row_sum = Q.sum(axis=1)
239 | print('iteration', t, '|',
240 | -one_over_n, Q.min(), rmse, max_error, '|',
241 | row_sum.min(), row_sum.max(), '|',
242 | np.trace(Q), np.trace(D.dot(Q)), '|',
243 | eta)
244 |
245 | if max_error < stop_tol_max and rmse < stop_tol_rmse:
246 | if verbose:
247 | row_sum = Q.sum(axis=1)
248 | print('iteration', t, '|',
249 | -one_over_n, Q.min(), rmse, max_error, '|',
250 | row_sum.min(), row_sum.max(), '|',
251 | np.trace(Q), np.trace(D.dot(Q)), '|',
252 | eta)
253 | break
254 |
255 | Q += one_over_n
256 |
257 | if verbose:
258 |         row_sum = Q.sum(axis=1)
259 | print('sum constraint', row_sum.min(), row_sum.max())
260 | print('trace constraint', np.trace(Q))
261 | print('nonnegative constraint', np.min(Q), np.mean(np.minimum(Q, 0)))
262 |
263 | print('final objective', np.trace(D.dot(Q)))
264 |
265 | if track_stats:
266 | return Q, rmse_list, obj_value_list
267 | else:
268 | return Q
269 |
--------------------------------------------------------------------------------
/experiments/nyuv2/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import wandb
3 | from argparse import ArgumentParser
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.utils.data import DataLoader
9 | from tqdm import trange
10 |
11 | from experiments.nyuv2.data import NYUv2
12 | from experiments.nyuv2.models import SegNet, SegNetMtan
13 | from experiments.nyuv2.utils import ConfMatrix, delta_fn, depth_error, normal_error
14 | from experiments.utils import (
15 | common_parser,
16 | extract_weight_method_parameters_from_args,
17 | get_device,
18 | set_logger,
19 | set_seed,
20 | str2bool,
21 | list_of_float,
22 | )
23 | from methods.weight_methods import WeightMethods
24 |
25 | #-----------------
27 | import os
28 | import time
29 | import sys
30 | #-----------------
31 |
32 | set_logger()
33 |
34 | #-----------------
35 | def log_string(file_out, out_str, print_out=True):
36 | file_out.write(out_str+'\n')
37 | file_out.flush()
38 | if print_out:
39 | print(out_str)
40 | # -----------------
41 |
42 | def calc_loss(x_pred, x_output, task_type):
43 | device = x_pred.device
44 |
45 | # binary mark to mask out undefined pixel space
46 | binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(device)
47 |
48 | if task_type == "semantic":
49 | # semantic loss: depth-wise cross entropy
50 | loss = F.nll_loss(x_pred, x_output, ignore_index=-1)
51 |
52 | if task_type == "depth":
53 | # depth loss: l1 norm
54 | loss = torch.sum(torch.abs(x_pred - x_output) * binary_mask) / torch.nonzero(
55 | binary_mask, as_tuple=False
56 | ).size(0)
57 |
58 | if task_type == "normal":
59 | # normal loss: dot product
60 | loss = 1 - torch.sum((x_pred * x_output) * binary_mask) / torch.nonzero(
61 | binary_mask, as_tuple=False
62 | ).size(0)
63 |
64 | return loss
65 |
66 |
67 | def main(args, file_out, file_weight_out, file_loss_before, file_loss_after, path, lr, bs, device):
68 | # ----
69 | # Nets
70 | # ---
71 |     model = dict(segnet=SegNet, mtan=SegNetMtan)[args.model]()  # instantiate only the selected model
72 | model = model.to(device)
73 |
74 | # dataset and dataloaders
75 | log_str = (
76 | "Applying data augmentation on NYUv2."
77 | if args.apply_augmentation
78 | else "Standard training strategy without data augmentation."
79 | )
80 | logging.info(log_str)
81 |
82 | nyuv2_train_set = NYUv2(
83 | root=path.as_posix(), train=True, augmentation=args.apply_augmentation
84 | )
85 | nyuv2_test_set = NYUv2(root=path.as_posix(), train=False)
86 |
87 | train_loader = torch.utils.data.DataLoader(
88 | dataset=nyuv2_train_set, batch_size=bs, shuffle=True
89 | )
90 |
91 | test_loader = torch.utils.data.DataLoader(
92 | dataset=nyuv2_test_set, batch_size=bs, shuffle=False
93 | )
94 |
95 | # weight method
96 | weight_methods_parameters = extract_weight_method_parameters_from_args(args)
97 | weight_method = WeightMethods(args.method, n_tasks=3, device=device, **weight_methods_parameters[args.method])
98 |
99 | # optimizer
100 | optimizer = torch.optim.Adam(
101 | [
102 | dict(params=model.parameters(), lr=lr),
103 | dict(params=weight_method.parameters(), lr=args.method_params_lr),
104 | ],
105 | )
106 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
107 |
108 | epochs = args.n_epochs
109 | epoch_iter = trange(epochs)
110 | train_batch = len(train_loader)
111 | test_batch = len(test_loader)
112 | avg_cost = np.zeros([epochs, 24], dtype=np.float32)
113 | # -----------------
114 | avg_overall_performance = np.zeros([epochs, 2], dtype=np.float32)
115 | # -----------------
116 | custom_step = -1
117 | conf_mat = ConfMatrix(model.segnet.class_nb)
118 |
119 | # training
120 | for epoch in epoch_iter:
121 | # ----
122 | start = time.time()
123 | # ----
124 |         cost = np.zeros(24, dtype=np.float32)
    |         conf_mat = ConfMatrix(model.segnet.class_nb)  # reset each epoch so test counts don't leak into train metrics
125 |
126 | for j, batch in enumerate(train_loader):
127 | #-----------------
129 | if args.debugging:
130 | if j == 2:
131 | break
132 | # -----------------
133 | custom_step += 1
134 | model.train()
135 | optimizer.zero_grad()
136 |
137 | train_data, train_label, train_depth, train_normal = batch
138 | train_data, train_label = train_data.to(device), train_label.long().to(
139 | device
140 | )
141 | train_depth, train_normal = train_depth.to(device), train_normal.to(device)
142 |
143 | train_pred, features = model(train_data, return_representation=True)
144 |
145 | losses = torch.stack(
146 | (
147 | calc_loss(train_pred[0], train_label, "semantic"),
148 | calc_loss(train_pred[1], train_depth, "depth"),
149 | calc_loss(train_pred[2], train_normal, "normal"),
150 | )
151 | )
152 |
154 | loss, extra_outputs = weight_method.backward(
155 | losses=losses,
156 | shared_parameters=list(model.shared_parameters()),
157 | task_specific_parameters=list(model.task_specific_parameters()),
158 | last_shared_parameters=list(model.last_shared_parameters()),
159 | representation=features,
160 | )
162 |
163 | optimizer.step()
164 |
165 | #----------updated from FAMO
166 | if "famo" in args.method:
167 | with torch.no_grad():
168 | train_pred = model(train_data, return_representation=False)
169 | new_losses = torch.stack(
170 | (
171 | calc_loss(train_pred[0], train_label, "semantic"),
172 | calc_loss(train_pred[1], train_depth, "depth"),
173 | calc_loss(train_pred[2], train_normal, "normal"),
174 | )
175 | )
176 | weight_method.method.update(new_losses.detach())
177 | #----------
178 |
179 | # accumulate label prediction for every pixel in training images
180 | conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten())
181 |
182 | cost[0] = losses[0].item()
183 | cost[3] = losses[1].item()
184 | cost[4], cost[5] = depth_error(train_pred[1], train_depth)
185 | cost[6] = losses[2].item()
186 | cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(train_pred[2], train_normal)
187 | avg_cost[epoch, :12] += cost[:12] / train_batch
188 |
189 | epoch_iter.set_description(
190 | f"[{epoch+1} {j+1}/{train_batch}] semantic loss: {losses[0].item():.3f}, "
191 | f"depth loss: {losses[1].item():.3f}, "
192 | f"normal loss: {losses[2].item():.3f}"
193 | )
194 |
195 | # scheduler
196 | scheduler.step()
197 |
198 | # compute mIoU and acc
199 | avg_cost[epoch, 1:3] = conf_mat.get_metrics()
200 |
201 | # todo: move evaluate to function?
202 | # evaluating test data
203 | model.eval()
204 | conf_mat = ConfMatrix(model.segnet.class_nb)
205 | with torch.no_grad(): # operations inside don't track history
206 | test_dataset = iter(test_loader)
207 | for k in range(test_batch):
208 | # -----------------
210 | if args.debugging:
211 | if k == 2:
212 | break
213 | # -----------------
214 |                 test_data, test_label, test_depth, test_normal = next(test_dataset)
215 | test_data, test_label = test_data.to(device), test_label.long().to(
216 | device
217 | )
218 | test_depth, test_normal = test_depth.to(device), test_normal.to(device)
219 |
220 | test_pred = model(test_data)
221 | test_loss = torch.stack(
222 | (
223 | calc_loss(test_pred[0], test_label, "semantic"),
224 | calc_loss(test_pred[1], test_depth, "depth"),
225 | calc_loss(test_pred[2], test_normal, "normal"),
226 | )
227 | )
228 |
229 | conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten())
230 |
231 | cost[12] = test_loss[0].item()
232 | cost[15] = test_loss[1].item()
233 | cost[16], cost[17] = depth_error(test_pred[1], test_depth)
234 | cost[18] = test_loss[2].item()
235 | cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(
236 | test_pred[2], test_normal
237 | )
238 | avg_cost[epoch, 12:] += cost[12:] / test_batch
239 | # compute mIoU and acc
240 | avg_cost[epoch, 13:15] = conf_mat.get_metrics()
241 |
242 | # Test Delta_m
243 | test_delta_m = delta_fn(
244 | avg_cost[epoch, [13, 14, 16, 17, 19, 20, 21, 22, 23]]
245 | )
246 |
247 | # -----------------
248 | avg_overall_performance[epoch, 0] = test_delta_m
249 | avg_overall_performance[epoch, 1] = avg_overall_performance[:epoch+1, 0].min()
250 | # -----------------
251 |
258 |
259 | # ----
260 | end = time.time()
261 | log_string(file_out, 'Epoch: {:04d} | Epoch time {:.4f}'.format(epoch, end - start))
263 | log_string(file_weight_out, 'Epoch {:04d} | {}'.format(epoch, extra_outputs["weights"]), False)
264 | log_string(file_loss_before, 'Epoch {:04d} | {}'.format(epoch, losses.detach().cpu()), False)
265 | log_string(file_loss_after, 'Epoch {:04d} | {}'.format(epoch, loss.detach().cpu()), False)
266 | # ----
267 |
268 | log_string(file_out, f"Epoch: {epoch:04d} | TRAIN: "
269 | f"{avg_cost[epoch, 0]:.4f} {avg_cost[epoch, 1]*100:.2f} {avg_cost[epoch, 2]*100:.2f} | "
270 | f"{avg_cost[epoch, 3]:.4f} {avg_cost[epoch, 4]:.4f} {avg_cost[epoch, 5]:.4f} | "
271 | f"{avg_cost[epoch, 6]:.4f} {avg_cost[epoch, 7]:.2f} {avg_cost[epoch, 8]:.2f} {avg_cost[epoch, 9]*100:.2f} {avg_cost[epoch, 10]*100:.2f} {avg_cost[epoch, 11]*100:.2f} || "
272 | f"TEST: "
273 | f"{avg_cost[epoch, 12]:.4f} {avg_cost[epoch, 13]*100:.2f} {avg_cost[epoch, 14]*100:.2f} | "
274 | f"{avg_cost[epoch, 15]:.4f} {avg_cost[epoch, 16]:.4f} {avg_cost[epoch, 17]:.4f} | "
275 | f"{avg_cost[epoch, 18]:.4f} {avg_cost[epoch, 19]:.2f} {avg_cost[epoch, 20]:.2f} {avg_cost[epoch, 21]*100:.2f} {avg_cost[epoch, 22]*100:.2f} {avg_cost[epoch, 23]*100:.2f} "
276 | f"| {test_delta_m:.3f} {avg_overall_performance[epoch, -1]:.3f}"
277 | )
278 |
279 | if wandb.run is not None:
280 | wandb.log({"Train Semantic Loss": avg_cost[epoch, 0]}, step=epoch)
281 | wandb.log({"Train Mean IoU": avg_cost[epoch, 1]}, step=epoch)
282 | wandb.log({"Train Pixel Accuracy": avg_cost[epoch, 2]}, step=epoch)
283 | wandb.log({"Train Depth Loss": avg_cost[epoch, 3]}, step=epoch)
284 | wandb.log({"Train Absolute Error": avg_cost[epoch, 4]}, step=epoch)
285 | wandb.log({"Train Relative Error": avg_cost[epoch, 5]}, step=epoch)
286 | wandb.log({"Train Normal Loss": avg_cost[epoch, 6]}, step=epoch)
287 | wandb.log({"Train Loss Mean": avg_cost[epoch, 7]}, step=epoch)
288 | wandb.log({"Train Loss Med": avg_cost[epoch, 8]}, step=epoch)
289 | wandb.log({"Train Loss <11.25": avg_cost[epoch, 9]}, step=epoch)
290 | wandb.log({"Train Loss <22.5": avg_cost[epoch, 10]}, step=epoch)
291 | wandb.log({"Train Loss <30": avg_cost[epoch, 11]}, step=epoch)
292 |
293 | wandb.log({"Test Semantic Loss": avg_cost[epoch, 12]}, step=epoch)
294 | wandb.log({"Test Mean IoU": avg_cost[epoch, 13]}, step=epoch)
295 | wandb.log({"Test Pixel Accuracy": avg_cost[epoch, 14]}, step=epoch)
296 | wandb.log({"Test Depth Loss": avg_cost[epoch, 15]}, step=epoch)
297 | wandb.log({"Test Absolute Error": avg_cost[epoch, 16]}, step=epoch)
298 | wandb.log({"Test Relative Error": avg_cost[epoch, 17]}, step=epoch)
299 | wandb.log({"Test Normal Loss": avg_cost[epoch, 18]}, step=epoch)
300 | wandb.log({"Test Loss Mean": avg_cost[epoch, 19]}, step=epoch)
301 | wandb.log({"Test Loss Med": avg_cost[epoch, 20]}, step=epoch)
302 | wandb.log({"Test Loss <11.25": avg_cost[epoch, 21]}, step=epoch)
303 | wandb.log({"Test Loss <22.5": avg_cost[epoch, 22]}, step=epoch)
304 | wandb.log({"Test Loss <30": avg_cost[epoch, 23]}, step=epoch)
305 | wandb.log({"Test ∆m": test_delta_m}, step=epoch)
306 |
312 |
313 | # -----------------
314 | # final output
315 | log_string(file_out, f"FORMAT: MEAN_IOU PIX_ACC | ABS_ERR REL_ERR | MEAN MED <11.25 <22.5 <30 | ∆m (test)")
316 | log_string(file_out,
317 | f"Last_10_TEST: "
318 | f"{avg_cost[-10:, 13].mean()*100:.2f} {avg_cost[-10:, 14].mean()*100:.2f} "
319 | f"{avg_cost[-10:, 16].mean():.4f} {avg_cost[-10:, 17].mean():.4f} "
320 | f"{avg_cost[-10:, 19].mean():.2f} {avg_cost[-10:, 20].mean():.2f} {avg_cost[-10:, 21].mean()*100:.2f} {avg_cost[-10:, 22].mean()*100:.2f} {avg_cost[-10:, 23].mean()*100:.2f} "
321 | f"{avg_overall_performance[-10:, 0].mean():.3f}"
322 | )
323 | # -----------------
324 |
326 |
327 | if __name__ == "__main__":
328 | parser = ArgumentParser("NYUv2", parents=[common_parser])
329 | parser.set_defaults(
330 |         data_path="../../../../../../../dataset/SIMO/nyuv2",
331 | lr=1e-4,
332 | n_epochs=200,
333 | batch_size=2,
334 | )
335 |
336 | parser.add_argument("--model", type=str, default="mtan", choices=["segnet", "mtan"], help="model type")
337 | parser.add_argument("--apply-augmentation", type=str2bool, default=True, help="data augmentations")
338 | # parser.add_argument("--wandb_project", type=str, default="nashmtl_nyuv2", help="Name of Weights & Biases Project.")
339 | # parser.add_argument("--wandb_entity", type=str, default="jia-yi9999", help="Name of Weights & Biases Entity.")
340 | parser.add_argument("--wandb_project", type=str, default=None, help="Name of Weights & Biases Project.")
341 | parser.add_argument("--wandb_entity", type=str, default=None, help="Name of Weights & Biases Entity.")
342 | # -----------------
343 | parser.add_argument('--debugging', action='store_true', help='with debugging')
344 | parser.add_argument('--log_name', default="logs_segnet_mtan/log", type=str, help='log name')
345 | parser.add_argument('--robust_step_size', default=0.0001, type=float, help='for our method')
346 | parser.add_argument('--task_weights', default="0.1,0.1,0.8", type=list_of_float, help='for group')
347 | parser.add_argument('--num_groups', default=1, type=int, help='number of groups')
348 | # -----------------
349 | args = parser.parse_args()
350 |
351 | # set seed
352 | set_seed(args.seed)
353 | #-----------------
354 | if args.method == "go4align":
355 | log_name = args.log_name + '_Method={}_Ngroups={}_Seed={}'.format(args.method, args.num_groups, args.seed)
356 | elif args.method == "group":
357 | log_name = args.log_name + '_Method={}_{}_{}_{}_Seed={}'.format(args.method, args.task_weights[0], args.task_weights[1], args.task_weights[2], args.seed)
358 | else:
359 | log_name = args.log_name + '_Method={}_Seed={}'.format(args.method, args.seed)
360 |
361 |     os.makedirs(log_name, exist_ok=True)
362 | file_out = open(log_name + "/train_log.txt", "w")
363 | file_weight_out = open(log_name + "/weight_log.txt", "w")
364 | file_loss_before = open(log_name + "/loss_before.txt", "w")
365 | file_loss_after = open(log_name+ "/loss_after.txt", "w")
366 |
367 |     os.makedirs(os.path.join(log_name, 'files'), exist_ok=True)
368 | os.system('cp %s %s' % ('*.py', os.path.join(log_name, 'files')))
369 | print(get_device(gpus=args.gpu))
370 | #--------------------
371 |
372 | if args.wandb_project is not None:
373 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, name=log_name, config=args)
374 |
375 | device = get_device(gpus=args.gpu)
376 | main(args, file_out, file_weight_out, file_loss_before, file_loss_after, path=args.data_path, lr=args.lr, bs=args.batch_size, device=device)
377 |
378 | if wandb.run is not None:
379 | wandb.finish()
380 |
381 | # -----------------
382 | st = ' '
383 | log_string(file_out, st.join(sys.argv))
384 | file_out.close()
385 | #--------------------
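
# --- Usage note (editor's sketch) --------------------------------------------
# Example invocation of this trainer; --method, --seed and --gpu are assumed to
# be defined in common_parser, which is not shown in this dump:
#
#   python trainer.py --method go4align --num_groups 2 --seed 0 \
#       --robust_step_size 1e-4 --log_name logs_segnet_mtan/log
# ------------------------------------------------------------------------------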
--------------------------------------------------------------------------------
/methods/cluster_methods.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | from abc import abstractmethod
4 | from typing import Dict, List, Tuple, Union
5 |
6 | import cvxpy as cp
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from scipy.optimize import minimize
11 |
12 | from kmeans_pytorch import kmeans  # GPU k-means; required by GO4ALIGN below
13 | from sklearn.cluster import KMeans, SpectralClustering  # CPU clustering
14 | # SpectralClustering assigns labels after the Laplacian embedding; the
15 | # 'cluster_qr' strategy extracts clusters directly from the eigenvectors.
16 | # Reference: Damle, Minden, Ying, "Simple, direct, and efficient multi-way spectral clustering", 2019.
17 |
18 | from .sdp_kmeans import sdp_kmeans, connected_components, spectral_embedding
19 | # https://github.com/simonsfoundation/sdp_kmeans.git
20 | # Mariano Tepper, Anirvan Sengupta, Dmitri Chklovskii, The surprising secret identity of the semidefinite relaxation of K-means: manifold learning, 2017
21 |
22 |
23 | class WeightMethod:
24 | def __init__(self, n_tasks: int, device: torch.device):
25 | super().__init__()
26 | self.n_tasks = n_tasks
27 | self.device = device
28 |
29 | @abstractmethod
30 | def get_weighted_loss(
31 | self,
32 | losses: torch.Tensor,
33 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
34 | task_specific_parameters: Union[
35 | List[torch.nn.parameter.Parameter], torch.Tensor
36 | ],
37 | last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
38 | representation: Union[torch.nn.parameter.Parameter, torch.Tensor],
39 | **kwargs,
40 | ):
41 | pass
42 |
43 | def backward(
44 | self,
45 | losses: torch.Tensor,
46 | shared_parameters: Union[
47 | List[torch.nn.parameter.Parameter], torch.Tensor
48 | ] = None,
49 | task_specific_parameters: Union[
50 | List[torch.nn.parameter.Parameter], torch.Tensor
51 | ] = None,
52 | last_shared_parameters: Union[
53 | List[torch.nn.parameter.Parameter], torch.Tensor
54 | ] = None,
55 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
56 | **kwargs,
57 | ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]:
58 | """
59 |
60 | Parameters
61 | ----------
62 | losses :
63 | shared_parameters :
64 | task_specific_parameters :
65 | last_shared_parameters : parameters of last shared layer/block
66 | representation : shared representation
67 | kwargs :
68 |
69 | Returns
70 | -------
71 | Loss, extra outputs
72 | """
73 | loss, extra_outputs = self.get_weighted_loss(
74 | losses=losses,
75 | shared_parameters=shared_parameters,
76 | task_specific_parameters=task_specific_parameters,
77 | last_shared_parameters=last_shared_parameters,
78 | representation=representation,
79 | **kwargs,
80 | )
81 | loss.backward()
82 | return loss, extra_outputs
83 |
84 | def __call__(
85 | self,
86 | losses: torch.Tensor,
87 | shared_parameters: Union[
88 | List[torch.nn.parameter.Parameter], torch.Tensor
89 | ] = None,
90 | task_specific_parameters: Union[
91 | List[torch.nn.parameter.Parameter], torch.Tensor
92 | ] = None,
93 | **kwargs,
94 | ):
95 | return self.backward(
96 | losses=losses,
97 | shared_parameters=shared_parameters,
98 | task_specific_parameters=task_specific_parameters,
99 | **kwargs,
100 | )
101 |
102 | def parameters(self) -> List[torch.Tensor]:
103 | """return learnable parameters"""
104 | return []
105 |
106 | class GO4ALIGN(WeightMethod):
107 | def __init__(
108 | self,
109 | n_tasks: int,
110 | device: torch.device,
111 | num_groups: int,
112 | task_weights: Union[List[float], torch.Tensor] = None,
113 | robust_step_size: float = 0.0001,
114 | ):
115 | super().__init__(n_tasks, device=device)
116 | self.n_tasks = n_tasks
117 | self.device = device
118 | if task_weights is None:
119 | task_weights = torch.ones((n_tasks,))
120 | if not isinstance(task_weights, torch.Tensor):
121 | task_weights = torch.tensor(task_weights)
122 | assert len(task_weights) == n_tasks
123 | self.task_weights = task_weights.to(device)
124 |
125 | self.adv_probs = torch.ones(n_tasks).to(device) / n_tasks
126 | self.robust_step_size = robust_step_size
127 | self.num_groups = num_groups
128 |
129 | def get_weighted_loss(self, losses, **kwargs):
130 | adjusted_loss = losses.detach()
131 | scale = adjusted_loss.sum() / adjusted_loss  # scale factor is larger for smaller losses
132 | self.adv_probs = self.adv_probs * torch.exp(-self.robust_step_size * adjusted_loss)
133 | self.adv_probs = self.adv_probs / self.adv_probs.sum()  # renormalize to a distribution
134 | weight = scale * self.adv_probs
135 |
136 | weight = weight.unsqueeze(1)
137 |
138 | if self.num_groups >=2:
139 | cluster_ids_x, cluster_centers = kmeans(X=weight, num_clusters=self.num_groups, distance='euclidean', device=self.device)
140 | mask = torch.zeros(self.n_tasks, self.num_groups).to(self.device)
141 | cluster_ids = cluster_ids_x.unsqueeze(1).to(self.device)
142 | cluster_centers = cluster_centers.to(self.device)
143 | mask.scatter_(1, cluster_ids, 1)
144 | kmeans_weight = torch.mm(mask,cluster_centers).squeeze(1)
145 |
146 | elif self.num_groups == 1:
147 | kmeans_weight = torch.ones(self.n_tasks).to(self.device)
148 | kmeans_weight = kmeans_weight * torch.mean(weight)
149 |
150 | loss = torch.sum(losses * kmeans_weight)
151 | return loss, dict(weights=torch.cat([kmeans_weight]))
152 |
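# --- Illustration (editor's sketch; not part of cluster_methods.py) ---------
# Numeric walk-through of GO4ALIGN.get_weighted_loss above for num_groups == 1,
# with made-up loss values (torch is already imported at the top of this file):
#
#     losses = torch.tensor([1.2, 0.4, 0.9])       # detached task losses
#     adv_probs = torch.ones(3) / 3                # uniform initialization
#     scale = losses.sum() / losses                # scale-invariant factor
#     adv_probs = adv_probs * torch.exp(-1e-4 * losses)
#     adv_probs = adv_probs / adv_probs.sum()      # renormalize
#     weight = scale * adv_probs                   # per-task weight
#     group_weight = torch.full((3,), weight.mean().item())
#
# With a single group every task gets the mean weight; with num_groups >= 2 the
# k-means branch instead replaces each task's weight by its cluster center.
# ---------------------------------------------------------------------------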
153 | class GROUP(WeightMethod):
154 | def __init__(
155 | self,
156 | n_tasks: int,
157 | device: torch.device,
158 | task_weights: Union[List[float], torch.Tensor] = None,
159 | robust_step_size: float = 0.0001,
160 | ):
161 | super().__init__(n_tasks, device=device)
162 | self.n_tasks = n_tasks
163 | if task_weights is None:
164 | task_weights = torch.ones((n_tasks,))
165 | if not isinstance(task_weights, torch.Tensor):
166 | task_weights = torch.tensor(task_weights)
167 | assert len(task_weights) == n_tasks
168 | self.task_weights = task_weights.to(device)
169 |
170 | def get_weighted_loss(self, losses, **kwargs):
171 | loss = torch.sum(losses * self.task_weights) * self.n_tasks
172 | return loss, dict(weights=torch.cat([self.task_weights]))
173 |
174 | class GROUP_RANDOM(WeightMethod):
175 | def __init__(
176 | self,
177 | n_tasks: int,
178 | device: torch.device,
179 | task_weights: Union[List[float], torch.Tensor] = None,
180 | robust_step_size: float = 0.0001,
181 | ):
182 | super().__init__(n_tasks, device=device)
183 | self.n_tasks = n_tasks
184 | if task_weights is None:
185 | task_weights = torch.ones((n_tasks,))
186 | if not isinstance(task_weights, torch.Tensor):
187 | task_weights = torch.tensor(task_weights)
188 | assert len(task_weights) == n_tasks
189 | self.task_weights = task_weights.to(device)
190 |
191 | def get_weighted_loss(self, losses, **kwargs):
192 | idx = torch.randperm(self.task_weights.shape[0])
193 | weights = self.task_weights[idx]
194 | loss = torch.sum(losses * weights) * self.n_tasks
195 | return loss, dict(weights=torch.cat([weights]))
196 |
197 | class GROUP_sklearn_spectral_clustering_cluster_qr(WeightMethod):
198 | def __init__(
199 | self,
200 | n_tasks: int,
201 | device: torch.device,
202 | num_groups: int,
203 | task_weights: Union[List[float], torch.Tensor] = None,
204 | robust_step_size: float = 0.0001,
205 | ):
206 | super().__init__(n_tasks, device=device)
207 | self.n_tasks = n_tasks
208 | self.device = device
209 | if task_weights is None:
210 | task_weights = torch.ones((n_tasks,))
211 | if not isinstance(task_weights, torch.Tensor):
212 | task_weights = torch.tensor(task_weights)
213 | assert len(task_weights) == n_tasks
214 | self.task_weights = task_weights.to(device)
215 |
216 | self.adv_probs = torch.ones(n_tasks).to(device) / n_tasks
217 | self.robust_step_size = robust_step_size
218 | self.num_groups = num_groups
219 |
220 | def get_weighted_loss(self, losses, **kwargs):
221 | adjusted_loss = losses.detach()
222 | scale = adjusted_loss.sum() / adjusted_loss
223 | self.adv_probs = self.adv_probs * torch.exp(- self.robust_step_size * adjusted_loss)
224 | self.adv_probs = self.adv_probs / (self.adv_probs.sum())
225 | weight = scale * self.adv_probs
226 |
227 | weight = weight.unsqueeze(1)
228 |
229 | if self.num_groups >=2:
230 | # pdb.set_trace()
231 | results = SpectralClustering(n_clusters=self.num_groups,
232 | assign_labels='cluster_qr',
233 | affinity='linear',
234 | random_state=0).fit(weight.cpu())
235 | cluster_ids = results.labels_
236 | cluster_ids = torch.tensor(cluster_ids).unsqueeze(1).to(self.device)
237 |
238 | assignment_matrix = torch.zeros(self.n_tasks, self.num_groups).to(self.device)
239 | cluster_ids = cluster_ids.to(torch.int64)
240 | assignment_matrix.scatter_(1, cluster_ids, 1)
241 | assignment_matrix = assignment_matrix.to(torch.float64)
242 |
243 | cluster_centers = (assignment_matrix * weight).sum(dim=0) / assignment_matrix.sum(0)
244 | cluster_centers = cluster_centers.unsqueeze(1)
245 |
246 | group_weight = torch.mm(assignment_matrix, cluster_centers).squeeze(1) # 3*2, 2*1
247 |
248 |
249 | elif self.num_groups == 1:
250 | group_weight = torch.ones(self.n_tasks).to(self.device)
251 | group_weight = group_weight * torch.mean(weight)
252 |
253 | loss = torch.sum(losses * group_weight)
254 | return loss, dict(weights=torch.cat([group_weight]))
255 |
256 | class GROUP_sklearn_spectral_clustering_discretize(WeightMethod):
257 | def __init__(
258 | self,
259 | n_tasks: int,
260 | device: torch.device,
261 | num_groups: int,
262 | task_weights: Union[List[float], torch.Tensor] = None,
263 | robust_step_size: float = 0.0001,
264 | ):
265 | super().__init__(n_tasks, device=device)
266 | self.n_tasks = n_tasks
267 | self.device = device
268 | if task_weights is None:
269 | task_weights = torch.ones((n_tasks,))
270 | if not isinstance(task_weights, torch.Tensor):
271 | task_weights = torch.tensor(task_weights)
272 | assert len(task_weights) == n_tasks
273 | self.task_weights = task_weights.to(device)
274 |
275 | self.adv_probs = torch.ones(n_tasks).to(device) / n_tasks
276 | self.robust_step_size = robust_step_size
277 | self.num_groups = num_groups
278 |
279 | def get_weighted_loss(self, losses, **kwargs):
280 | adjusted_loss = losses.detach()
281 | scale = adjusted_loss.sum() / adjusted_loss
282 | self.adv_probs = self.adv_probs * torch.exp(- self.robust_step_size * adjusted_loss)
283 | self.adv_probs = self.adv_probs / (self.adv_probs.sum())
284 | weight = scale * self.adv_probs
285 |
286 | weight = weight.unsqueeze(1)
287 |
288 | if self.num_groups >=2:
289 | # pdb.set_trace()
290 | results = SpectralClustering(n_clusters=self.num_groups,
291 | assign_labels='discretize',
292 | affinity='linear',
293 | random_state=0).fit(weight.cpu())
294 | cluster_ids = results.labels_
295 | cluster_ids = torch.tensor(cluster_ids).unsqueeze(1).to(self.device)
296 |
297 | assignment_matrix = torch.zeros(self.n_tasks, self.num_groups).to(self.device)
298 | cluster_ids = cluster_ids.to(torch.int64)
299 | assignment_matrix.scatter_(1, cluster_ids, 1)
300 | assignment_matrix = assignment_matrix.to(torch.float64)
301 |
302 | cluster_centers = (assignment_matrix * weight).sum(dim=0) / assignment_matrix.sum(0)
303 | cluster_centers = cluster_centers.unsqueeze(1)
304 |
305 | group_weight = torch.mm(assignment_matrix, cluster_centers).squeeze(1) # 3*2, 2*1
306 |
307 |
308 | elif self.num_groups == 1:
309 | group_weight = torch.ones(self.n_tasks).to(self.device)
310 | group_weight = group_weight * torch.mean(weight)
311 |
312 | loss = torch.sum(losses * group_weight)
313 | return loss, dict(weights=torch.cat([group_weight]))
314 |
315 | class GROUP_sklearn_spectral_clustering_kmeans(WeightMethod):
316 | def __init__(
317 | self,
318 | n_tasks: int,
319 | device: torch.device,
320 | num_groups: int,
321 | task_weights: Union[List[float], torch.Tensor] = None,
322 | robust_step_size: float = 0.0001,
323 | ):
324 | super().__init__(n_tasks, device=device)
325 | self.n_tasks = n_tasks
326 | self.device = device
327 | if task_weights is None:
328 | task_weights = torch.ones((n_tasks,))
329 | if not isinstance(task_weights, torch.Tensor):
330 | task_weights = torch.tensor(task_weights)
331 | assert len(task_weights) == n_tasks
332 | self.task_weights = task_weights.to(device)
333 |
334 | self.adv_probs = torch.ones(n_tasks).to(device) / n_tasks
335 | self.robust_step_size = robust_step_size
336 | self.num_groups = num_groups
337 |
338 | def get_weighted_loss(self, losses, **kwargs):
339 | adjusted_loss = losses.detach()
340 | scale = adjusted_loss.sum() / adjusted_loss
341 | self.adv_probs = self.adv_probs * torch.exp(- self.robust_step_size * adjusted_loss)
342 | self.adv_probs = self.adv_probs / (self.adv_probs.sum())
343 | weight = scale * self.adv_probs
344 |
345 | weight = weight.unsqueeze(1)
346 |
347 | if self.num_groups >=2:
348 | # pdb.set_trace()
349 | results = SpectralClustering(n_clusters=self.num_groups,
350 | assign_labels='kmeans',
351 | affinity='linear',
352 | random_state=0).fit(weight.cpu())
353 | cluster_ids = results.labels_
354 | cluster_ids = torch.tensor(cluster_ids).unsqueeze(1).to(self.device)
355 |
356 | assignment_matrix = torch.zeros(self.n_tasks, self.num_groups).to(self.device)
357 | cluster_ids = cluster_ids.to(torch.int64)
358 | assignment_matrix.scatter_(1, cluster_ids, 1)
359 | assignment_matrix = assignment_matrix.to(torch.float64)
360 |
361 | cluster_centers = (assignment_matrix * weight).sum(dim=0) / assignment_matrix.sum(0)
362 | cluster_centers = cluster_centers.unsqueeze(1)
363 |
364 | group_weight = torch.mm(assignment_matrix, cluster_centers).squeeze(1) # 3*2, 2*1
365 |
366 |
367 | elif self.num_groups == 1:
368 | group_weight = torch.ones(self.n_tasks).to(self.device)
369 | group_weight = group_weight * torch.mean(weight)
370 |
371 | loss = torch.sum(losses * group_weight)
372 | return loss, dict(weights=torch.cat([group_weight]))
373 |
374 | class GROUP_sdp_clustering(WeightMethod):
375 | def __init__(
376 | self,
377 | n_tasks: int,
378 | device: torch.device,
379 | num_groups: int,
380 | task_weights: Union[List[float], torch.Tensor] = None,
381 | robust_step_size: float = 0.0001,
382 | ):
383 | super().__init__(n_tasks, device=device)
384 | self.n_tasks = n_tasks
385 | self.device = device
386 | if task_weights is None:
387 | task_weights = torch.ones((n_tasks,))
388 | if not isinstance(task_weights, torch.Tensor):
389 | task_weights = torch.tensor(task_weights)
390 | assert len(task_weights) == n_tasks
391 | self.task_weights = task_weights.to(device)
392 |
393 | self.adv_probs = torch.ones(n_tasks).to(device) / n_tasks
394 | self.robust_step_size = robust_step_size
395 | self.num_groups = num_groups
396 |
397 | def get_weighted_loss(self, losses, **kwargs):
398 | adjusted_loss = losses.detach()
399 | scale = adjusted_loss.sum() / adjusted_loss
400 | self.adv_probs = self.adv_probs * torch.exp(- self.robust_step_size * adjusted_loss)
401 | self.adv_probs = self.adv_probs / (self.adv_probs.sum())
402 | weight = scale * self.adv_probs
403 |
404 | weight = weight.unsqueeze(1)
405 |
406 | if self.num_groups >=2:
407 | D, Q = sdp_kmeans(X=weight.cpu().numpy(), n_clusters=self.num_groups, method='cvx')
408 | assignment_matrix_init = connected_components(Q)
409 | assignment_matrix = torch.tensor(assignment_matrix_init).to(self.device)
410 | assignment_matrix = assignment_matrix.transpose(0,1).to(torch.float64)
411 |
412 | cluster_centers = (assignment_matrix * weight).sum(dim=0) / assignment_matrix.sum(0)
413 | cluster_centers = cluster_centers.unsqueeze(1)
414 |
415 | group_weight = torch.mm(assignment_matrix, cluster_centers).squeeze(1) # 3*2, 2*1
416 |
417 | elif self.num_groups == 1:
418 | group_weight = torch.ones(self.n_tasks).to(self.device)
419 | group_weight = group_weight * torch.mean(weight)
420 |
421 | loss = torch.sum(losses * group_weight)
422 | return loss, dict(weights=torch.cat([group_weight]))
423 |
424 |
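# --- Illustration (editor's sketch; not part of cluster_methods.py) ---------
# The GROUP_* classes above share one grouping step: hard cluster labels are
# turned into a one-hot assignment matrix, each group is replaced by its mean
# weight, and the mean is broadcast back to the tasks. With made-up values:
#
#     weight = torch.tensor([[0.9], [1.1], [2.5]])   # (n_tasks, 1)
#     cluster_ids = torch.tensor([[0], [0], [1]])    # labels from a clusterer
#     assignment = torch.zeros(3, 2)
#     assignment.scatter_(1, cluster_ids, 1)         # one-hot (n_tasks, groups)
#     centers = (assignment * weight).sum(0) / assignment.sum(0)
#     group_weight = assignment @ centers.unsqueeze(1)   # [[1.0], [1.0], [2.5]]
# ---------------------------------------------------------------------------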
--------------------------------------------------------------------------------
/experiments/nyuv2/models.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class _SegNet(nn.Module):
9 | """SegNet MTAN"""
10 |
11 | def __init__(self):
12 | super(_SegNet, self).__init__()
13 | # initialise network parameters
14 | filter = [64, 128, 256, 512, 512]
15 | self.class_nb = 13
16 |
17 | # define encoder decoder layers
18 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
19 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
20 | for i in range(4):
21 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
22 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))
23 |
24 | # define convolution layer
25 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
26 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
27 | for i in range(4):
28 | if i == 0:
29 | self.conv_block_enc.append(
30 | self.conv_layer([filter[i + 1], filter[i + 1]])
31 | )
32 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
33 | else:
34 | self.conv_block_enc.append(
35 | nn.Sequential(
36 | self.conv_layer([filter[i + 1], filter[i + 1]]),
37 | self.conv_layer([filter[i + 1], filter[i + 1]]),
38 | )
39 | )
40 | self.conv_block_dec.append(
41 | nn.Sequential(
42 | self.conv_layer([filter[i], filter[i]]),
43 | self.conv_layer([filter[i], filter[i]]),
44 | )
45 | )
46 |
47 | # define task attention layers
48 | self.encoder_att = nn.ModuleList(
49 | [nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])]
50 | )
51 | self.decoder_att = nn.ModuleList(
52 | [nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])])]
53 | )
54 | self.encoder_block_att = nn.ModuleList(
55 | [self.conv_layer([filter[0], filter[1]])]
56 | )
57 | self.decoder_block_att = nn.ModuleList(
58 | [self.conv_layer([filter[0], filter[0]])]
59 | )
60 |
61 | for j in range(3):
62 | if j < 2:
63 | self.encoder_att.append(
64 | nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])
65 | )
66 | self.decoder_att.append(
67 | nn.ModuleList(
68 | [self.att_layer([2 * filter[0], filter[0], filter[0]])]
69 | )
70 | )
71 | for i in range(4):
72 | self.encoder_att[j].append(
73 | self.att_layer([2 * filter[i + 1], filter[i + 1], filter[i + 1]])
74 | )
75 | self.decoder_att[j].append(
76 | self.att_layer([filter[i + 1] + filter[i], filter[i], filter[i]])
77 | )
78 |
79 | for i in range(4):
80 | if i < 3:
81 | self.encoder_block_att.append(
82 | self.conv_layer([filter[i + 1], filter[i + 2]])
83 | )
84 | self.decoder_block_att.append(
85 | self.conv_layer([filter[i + 1], filter[i]])
86 | )
87 | else:
88 | self.encoder_block_att.append(
89 | self.conv_layer([filter[i + 1], filter[i + 1]])
90 | )
91 | self.decoder_block_att.append(
92 | self.conv_layer([filter[i + 1], filter[i + 1]])
93 | )
94 |
95 | self.pred_task1 = self.conv_layer([filter[0], self.class_nb], pred=True)
96 | self.pred_task2 = self.conv_layer([filter[0], 1], pred=True)
97 | self.pred_task3 = self.conv_layer([filter[0], 3], pred=True)
98 |
99 | # define pooling and unpooling functions
100 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
101 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | nn.init.xavier_normal_(m.weight)
106 | nn.init.constant_(m.bias, 0)
107 | elif isinstance(m, nn.BatchNorm2d):
108 | nn.init.constant_(m.weight, 1)
109 | nn.init.constant_(m.bias, 0)
110 | elif isinstance(m, nn.Linear):
111 | nn.init.xavier_normal_(m.weight)
112 | nn.init.constant_(m.bias, 0)
113 |
114 | def shared_modules(self):
115 | return [
116 | self.encoder_block,
117 | self.decoder_block,
118 | self.conv_block_enc,
119 | self.conv_block_dec,
120 | # self.encoder_att, self.decoder_att,
121 | self.encoder_block_att,
122 | self.decoder_block_att,
123 | self.down_sampling,
124 | self.up_sampling,
125 | ]
126 |
127 | def zero_grad_shared_modules(self):
128 | for mm in self.shared_modules():
129 | mm.zero_grad()
130 |
131 | def conv_layer(self, channel, pred=False):
132 | if not pred:
133 | conv_block = nn.Sequential(
134 | nn.Conv2d(
135 | in_channels=channel[0],
136 | out_channels=channel[1],
137 | kernel_size=3,
138 | padding=1,
139 | ),
140 | nn.BatchNorm2d(num_features=channel[1]),
141 | nn.ReLU(inplace=True),
142 | )
143 | else:
144 | conv_block = nn.Sequential(
145 | nn.Conv2d(
146 | in_channels=channel[0],
147 | out_channels=channel[0],
148 | kernel_size=3,
149 | padding=1,
150 | ),
151 | nn.Conv2d(
152 | in_channels=channel[0],
153 | out_channels=channel[1],
154 | kernel_size=1,
155 | padding=0,
156 | ),
157 | )
158 | return conv_block
159 |
160 | def att_layer(self, channel):
161 | att_block = nn.Sequential(
162 | nn.Conv2d(
163 | in_channels=channel[0],
164 | out_channels=channel[1],
165 | kernel_size=1,
166 | padding=0,
167 | ),
168 | nn.BatchNorm2d(channel[1]),
169 | nn.ReLU(inplace=True),
170 | nn.Conv2d(
171 | in_channels=channel[1],
172 | out_channels=channel[2],
173 | kernel_size=1,
174 | padding=0,
175 | ),
176 | nn.BatchNorm2d(channel[2]),
177 | nn.Sigmoid(),
178 | )
179 | return att_block
180 |
181 | def forward(self, x):
182 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = (
183 | [0] * 5 for _ in range(5)
184 | )
185 | for i in range(5):
186 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))
187 |
188 | # define attention list for tasks
189 | atten_encoder, atten_decoder = ([0] * 3 for _ in range(2))
190 | for i in range(3):
191 | atten_encoder[i], atten_decoder[i] = ([0] * 5 for _ in range(2))
192 | for i in range(3):
193 | for j in range(5):
194 | atten_encoder[i][j], atten_decoder[i][j] = ([0] * 3 for _ in range(2))
195 |
196 | # define global shared network
197 | for i in range(5):
198 | if i == 0:
199 | g_encoder[i][0] = self.encoder_block[i](x)
200 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
201 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
202 | else:
203 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
204 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
205 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
206 |
207 | for i in range(5):
208 | if i == 0:
209 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
210 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
211 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
212 | else:
213 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
214 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
215 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
216 |
217 | # define task dependent attention module
218 | for i in range(3):
219 | for j in range(5):
220 | if j == 0:
221 | atten_encoder[i][j][0] = self.encoder_att[i][j](g_encoder[j][0])
222 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
223 | atten_encoder[i][j][2] = self.encoder_block_att[j](
224 | atten_encoder[i][j][1]
225 | )
226 | atten_encoder[i][j][2] = F.max_pool2d(
227 | atten_encoder[i][j][2], kernel_size=2, stride=2
228 | )
229 | else:
230 | atten_encoder[i][j][0] = self.encoder_att[i][j](
231 | torch.cat((g_encoder[j][0], atten_encoder[i][j - 1][2]), dim=1)
232 | )
233 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
234 | atten_encoder[i][j][2] = self.encoder_block_att[j](
235 | atten_encoder[i][j][1]
236 | )
237 | atten_encoder[i][j][2] = F.max_pool2d(
238 | atten_encoder[i][j][2], kernel_size=2, stride=2
239 | )
240 |
241 | for j in range(5):
242 | if j == 0:
243 | atten_decoder[i][j][0] = F.interpolate(
244 | atten_encoder[i][-1][-1],
245 | scale_factor=2,
246 | mode="bilinear",
247 | align_corners=True,
248 | )
249 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](
250 | atten_decoder[i][j][0]
251 | )
252 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](
253 | torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1)
254 | )
255 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]
256 | else:
257 | atten_decoder[i][j][0] = F.interpolate(
258 | atten_decoder[i][j - 1][2],
259 | scale_factor=2,
260 | mode="bilinear",
261 | align_corners=True,
262 | )
263 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](
264 | atten_decoder[i][j][0]
265 | )
266 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](
267 | torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1)
268 | )
269 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]
270 |
271 | # define task prediction layers
272 | t1_pred = F.log_softmax(self.pred_task1(atten_decoder[0][-1][-1]), dim=1)
273 | t2_pred = self.pred_task2(atten_decoder[1][-1][-1])
274 | t3_pred = self.pred_task3(atten_decoder[2][-1][-1])
275 | t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)
276 |
277 | return (
278 | [t1_pred, t2_pred, t3_pred],
279 | (
280 | atten_decoder[0][-1][-1],
281 | atten_decoder[1][-1][-1],
282 | atten_decoder[2][-1][-1],
283 | ),
284 | )
285 |
286 |
287 | class SegNetMtan(nn.Module):
288 | def __init__(self):
289 | super().__init__()
290 | self.segnet = _SegNet()
291 |
292 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
293 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n)
294 |
295 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]:
296 | return (p for n, p in self.segnet.named_parameters() if "pred" in n)
297 |
298 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
299 | """Parameters of the last shared layer.
300 | Returns
301 | -------
302 | """
303 | return []
304 |
305 | def forward(self, x, return_representation=False):
306 | if return_representation:
307 | return self.segnet(x)
308 | else:
309 | pred, rep = self.segnet(x)
310 | return pred
311 |
312 |
313 | class SegNetSplit(nn.Module):
314 | def __init__(self, model_type="standard"):
315 | super(SegNetSplit, self).__init__()
316 | # initialise network parameters
317 | assert model_type in ["standard", "wide", "deep"]
318 | self.model_type = model_type
319 | if self.model_type == "wide":
320 | filter = [64, 128, 256, 512, 1024]
321 | else:
322 | filter = [64, 128, 256, 512, 512]
323 |
324 | self.class_nb = 13
325 |
326 | # define encoder decoder layers
327 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
328 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
329 | for i in range(4):
330 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
331 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))
332 |
333 | # define convolution layer
334 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
335 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
336 | for i in range(4):
337 | if i == 0:
338 | self.conv_block_enc.append(
339 | self.conv_layer([filter[i + 1], filter[i + 1]])
340 | )
341 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
342 | else:
343 | self.conv_block_enc.append(
344 | nn.Sequential(
345 | self.conv_layer([filter[i + 1], filter[i + 1]]),
346 | self.conv_layer([filter[i + 1], filter[i + 1]]),
347 | )
348 | )
349 | self.conv_block_dec.append(
350 | nn.Sequential(
351 | self.conv_layer([filter[i], filter[i]]),
352 | self.conv_layer([filter[i], filter[i]]),
353 | )
354 | )
355 |
356 | # define task specific layers
357 | self.pred_task1 = nn.Sequential(
358 | nn.Conv2d(
359 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
360 | ),
361 | nn.Conv2d(
362 | in_channels=filter[0],
363 | out_channels=self.class_nb,
364 | kernel_size=1,
365 | padding=0,
366 | ),
367 | )
368 | self.pred_task2 = nn.Sequential(
369 | nn.Conv2d(
370 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
371 | ),
372 | nn.Conv2d(in_channels=filter[0], out_channels=1, kernel_size=1, padding=0),
373 | )
374 | self.pred_task3 = nn.Sequential(
375 | nn.Conv2d(
376 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
377 | ),
378 | nn.Conv2d(in_channels=filter[0], out_channels=3, kernel_size=1, padding=0),
379 | )
380 |
381 | # define pooling and unpooling functions
382 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
383 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)
384 |
385 | for m in self.modules():
386 | if isinstance(m, nn.Conv2d):
387 | nn.init.xavier_normal_(m.weight)
388 | nn.init.constant_(m.bias, 0)
389 | elif isinstance(m, nn.BatchNorm2d):
390 | nn.init.constant_(m.weight, 1)
391 | nn.init.constant_(m.bias, 0)
392 | elif isinstance(m, nn.Linear):
393 | nn.init.xavier_normal_(m.weight)
394 | nn.init.constant_(m.bias, 0)
395 |
396 | # define convolutional block
397 | def conv_layer(self, channel):
398 | if self.model_type == "deep":
399 | conv_block = nn.Sequential(
400 | nn.Conv2d(
401 | in_channels=channel[0],
402 | out_channels=channel[1],
403 | kernel_size=3,
404 | padding=1,
405 | ),
406 | nn.BatchNorm2d(num_features=channel[1]),
407 | nn.ReLU(inplace=True),
408 | nn.Conv2d(
409 | in_channels=channel[1],
410 | out_channels=channel[1],
411 | kernel_size=3,
412 | padding=1,
413 | ),
414 | nn.BatchNorm2d(num_features=channel[1]),
415 | nn.ReLU(inplace=True),
416 | )
417 | else:
418 | conv_block = nn.Sequential(
419 | nn.Conv2d(
420 | in_channels=channel[0],
421 | out_channels=channel[1],
422 | kernel_size=3,
423 | padding=1,
424 | ),
425 | nn.BatchNorm2d(num_features=channel[1]),
426 | nn.ReLU(inplace=True),
427 | )
428 | return conv_block
429 |
430 | def forward(self, x):
431 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = (
432 | [0] * 5 for _ in range(5)
433 | )
434 | for i in range(5):
435 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))
436 |
437 | # global shared encoder-decoder network
438 | for i in range(5):
439 | if i == 0:
440 | g_encoder[i][0] = self.encoder_block[i](x)
441 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
442 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
443 | else:
444 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
445 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
446 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
447 |
448 | for i in range(5):
449 | if i == 0:
450 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
451 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
452 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
453 | else:
454 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
455 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
456 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
457 |
458 | # define task prediction layers
459 | t1_pred = F.log_softmax(self.pred_task1(g_decoder[-1][1]), dim=1)
460 | t2_pred = self.pred_task2(g_decoder[-1][1])
461 | t3_pred = self.pred_task3(g_decoder[-1][1])
462 | t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)
463 |
464 | # NOTE: the second return value is the shared representation
465 | return [t1_pred, t2_pred, t3_pred], g_decoder[-1][1]
466 |
467 |
468 |
469 | class SegNet(nn.Module):
470 | def __init__(self):
471 | super().__init__()
472 | self.segnet = SegNetSplit()
473 |
474 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
475 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n)
476 |
477 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]:
478 | return (p for n, p in self.segnet.named_parameters() if "pred" in n)
479 |
480 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
481 | """Parameters of the last shared layer.
482 | Returns
483 | -------
484 | """
485 | return self.segnet.conv_block_dec[-5].parameters()
486 |
487 | def forward(self, x, return_representation=False):
488 | if return_representation:
489 | return self.segnet(x)
490 | else:
491 | pred, rep = self.segnet(x)
492 | return pred
493 |
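# --- Illustration (editor's sketch; not part of models.py) ------------------
# Output shapes of the models above for a hypothetical input (NYUv2 crops in
# related codebases are 288x384; any size divisible by 32 works with the five
# pooling stages):
#
#     model = SegNetMtan()
#     x = torch.randn(2, 3, 288, 384)
#     (seg, depth, normal), rep = model(x, return_representation=True)
#     # seg:    (2, 13, 288, 384) log-probabilities over 13 classes
#     # depth:  (2, 1, 288, 384)
#     # normal: (2, 3, 288, 384), L2-normalized along the channel dim
# ---------------------------------------------------------------------------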
--------------------------------------------------------------------------------
/methods/weight_methods.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | from abc import abstractmethod
4 | from typing import Dict, List, Tuple, Union
5 |
6 | import cvxpy as cp
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from scipy.optimize import minimize
11 |
12 | from methods.min_norm_solvers import MinNormSolver, gradient_normalizers
13 | from methods.cluster_methods import *
14 |
15 | class WeightMethod:
16 | def __init__(self, n_tasks: int, device: torch.device):
17 | super().__init__()
18 | self.n_tasks = n_tasks
19 | self.device = device
20 |
21 | @abstractmethod
22 | def get_weighted_loss(
23 | self,
24 | losses: torch.Tensor,
25 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
26 | task_specific_parameters: Union[
27 | List[torch.nn.parameter.Parameter], torch.Tensor
28 | ],
29 | last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
30 | representation: Union[torch.nn.parameter.Parameter, torch.Tensor],
31 | **kwargs,
32 | ):
33 | pass
34 |
35 | def backward(
36 | self,
37 | losses: torch.Tensor,
38 | shared_parameters: Union[
39 | List[torch.nn.parameter.Parameter], torch.Tensor
40 | ] = None,
41 | task_specific_parameters: Union[
42 | List[torch.nn.parameter.Parameter], torch.Tensor
43 | ] = None,
44 | last_shared_parameters: Union[
45 | List[torch.nn.parameter.Parameter], torch.Tensor
46 | ] = None,
47 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
48 | **kwargs,
49 | ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]:
50 | """
51 |
52 | Parameters
53 | ----------
54 | losses :
55 | shared_parameters :
56 | task_specific_parameters :
57 | last_shared_parameters : parameters of last shared layer/block
58 | representation : shared representation
59 | kwargs :
60 |
61 | Returns
62 | -------
63 | Loss, extra outputs
64 | """
65 | loss, extra_outputs = self.get_weighted_loss(
66 | losses=losses,
67 | shared_parameters=shared_parameters,
68 | task_specific_parameters=task_specific_parameters,
69 | last_shared_parameters=last_shared_parameters,
70 | representation=representation,
71 | **kwargs,
72 | )
73 | loss.backward()
74 | return loss, extra_outputs
75 |
76 | def __call__(
77 | self,
78 | losses: torch.Tensor,
79 | shared_parameters: Union[
80 | List[torch.nn.parameter.Parameter], torch.Tensor
81 | ] = None,
82 | task_specific_parameters: Union[
83 | List[torch.nn.parameter.Parameter], torch.Tensor
84 | ] = None,
85 | **kwargs,
86 | ):
87 | return self.backward(
88 | losses=losses,
89 | shared_parameters=shared_parameters,
90 | task_specific_parameters=task_specific_parameters,
91 | **kwargs,
92 | )
93 |
94 | def parameters(self) -> List[torch.Tensor]:
95 | """return learnable parameters"""
96 | return []
97 |
98 | class NashMTL(WeightMethod):
99 | def __init__(
100 | self,
101 | n_tasks: int,
102 | device: torch.device,
103 | max_norm: float = 1.0,
104 | update_weights_every: int = 1,
105 | optim_niter=20,
106 | ):
107 | super(NashMTL, self).__init__(
108 | n_tasks=n_tasks,
109 | device=device,
110 | )
111 |
112 | self.optim_niter = optim_niter
113 | self.update_weights_every = update_weights_every
114 | self.max_norm = max_norm
115 |
116 | self.prvs_alpha_param = None
117 | self.normalization_factor = np.ones((1,))
118 | self.init_gtg = np.eye(self.n_tasks)
119 | self.step = 0.0
120 | self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)
121 |
122 | def _stop_criteria(self, gtg, alpha_t):
123 | return (
124 | (self.alpha_param.value is None)
125 | or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3)
126 | or (
127 | np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value)
128 | < 1e-6
129 | )
130 | )
131 |
132 | def solve_optimization(self, gtg: np.array):
133 | self.G_param.value = gtg
134 | self.normalization_factor_param.value = self.normalization_factor
135 |
136 | alpha_t = self.prvs_alpha
137 | for _ in range(self.optim_niter):
138 | self.alpha_param.value = alpha_t
139 | self.prvs_alpha_param.value = alpha_t
140 |
141 | try:
142 | self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100)
143 | except Exception:  # solver failure: fall back to the previous alpha
144 | self.alpha_param.value = self.prvs_alpha_param.value
145 |
146 | if self._stop_criteria(gtg, alpha_t):
147 | break
148 |
149 | alpha_t = self.alpha_param.value
150 |
151 | if alpha_t is not None:
152 | self.prvs_alpha = alpha_t
153 |
154 | return self.prvs_alpha
155 |
156 | def _calc_phi_alpha_linearization(self):
157 | G_prvs_alpha = self.G_param @ self.prvs_alpha_param
158 | prvs_phi_tag = 1 / self.prvs_alpha_param + (1 / G_prvs_alpha) @ self.G_param
159 | phi_alpha = prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param)
160 | return phi_alpha
161 |
162 | def _init_optim_problem(self):
163 | self.alpha_param = cp.Variable(shape=(self.n_tasks,), nonneg=True)
164 | self.prvs_alpha_param = cp.Parameter(
165 | shape=(self.n_tasks,), value=self.prvs_alpha
166 | )
167 | self.G_param = cp.Parameter(
168 | shape=(self.n_tasks, self.n_tasks), value=self.init_gtg
169 | )
170 | self.normalization_factor_param = cp.Parameter(
171 | shape=(1,), value=np.array([1.0])
172 | )
173 |
174 | self.phi_alpha = self._calc_phi_alpha_linearization()
175 |
176 | G_alpha = self.G_param @ self.alpha_param
177 | constraint = []
178 | for i in range(self.n_tasks):
179 | constraint.append(
180 | -cp.log(self.alpha_param[i] * self.normalization_factor_param)
181 | - cp.log(G_alpha[i])
182 | <= 0
183 | )
184 | obj = cp.Minimize(
185 | cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param
186 | )
187 | self.prob = cp.Problem(obj, constraint)
188 |
189 | def get_weighted_loss(
190 | self,
191 | losses,
192 | shared_parameters,
193 | **kwargs,
194 | ):
195 | """
196 |
197 | Parameters
198 | ----------
199 | losses :
200 | shared_parameters : shared parameters
201 | kwargs :
202 |
203 | Returns
204 | -------
205 |
206 | """
207 |
208 | extra_outputs = dict()
209 | if self.step == 0:
210 | self._init_optim_problem()
211 |
212 | if (self.step % self.update_weights_every) == 0:
213 | self.step += 1
214 |
215 | grads = {}
216 | for i, loss in enumerate(losses):
217 | g = list(
218 | torch.autograd.grad(
219 | loss,
220 | shared_parameters,
221 | retain_graph=True,
222 | )
223 | )
224 | grad = torch.cat([torch.flatten(grad) for grad in g])
225 | grads[i] = grad  # flattened gradient over all shared parameters
226 |
227 | G = torch.stack(tuple(v for v in grads.values()))  # (n_tasks, n_params)
228 | GTG = torch.mm(G, G.t())  # (n_tasks, n_tasks) Gram matrix
229 |
230 | self.normalization_factor = torch.norm(GTG).detach().cpu().numpy().reshape((1,))
231 | GTG = GTG / self.normalization_factor.item()  # normalize before the cvxpy solve
232 | alpha = self.solve_optimization(GTG.cpu().detach().numpy())
233 | alpha = torch.from_numpy(alpha)
234 |
235 | else:
236 | self.step += 1
237 | alpha = self.prvs_alpha
238 | # -----------------
239 | # updated_gradient = alpha.unsqueeze(1).float().cuda() * G
240 | # updated_overall = updated_gradient.sum(0)
241 | # # print(updated_overall.norm() - np.sqrt(3))
242 | # projection = updated_overall @ torch.transpose(updated_gradient, 0, 1)
243 | # # print(projection)
244 | # # cosine = projection / (updated_overall.norm() * updated_gradient.norm(dim=1))
245 | # norm_gradient = updated_gradient / updated_gradient.norm(dim=1).unsqueeze(1)
246 | #
247 | # norm_gradient @ torch.transpose(norm_gradient, 0, 1)
248 | # norm_G = G / G.norm(dim=1).unsqueeze(1)
249 | # norm_G @ torch.transpose(norm_G, 0, 1)
250 | # -----------------
251 | weighted_loss = sum([losses[i] * alpha[i] for i in range(len(alpha))])
252 | extra_outputs["weights"] = alpha
253 | return weighted_loss, extra_outputs
254 |
255 | def backward(
256 | self,
257 | losses: torch.Tensor,
258 | shared_parameters: Union[
259 | List[torch.nn.parameter.Parameter], torch.Tensor
260 | ] = None,
261 | task_specific_parameters: Union[
262 | List[torch.nn.parameter.Parameter], torch.Tensor
263 | ] = None,
264 | last_shared_parameters: Union[
265 | List[torch.nn.parameter.Parameter], torch.Tensor
266 | ] = None,
267 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
268 | **kwargs,
269 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
270 | loss, extra_outputs = self.get_weighted_loss(
271 | losses=losses,
272 | shared_parameters=shared_parameters,
273 | **kwargs,
274 | )
275 | loss.backward()
276 |
277 | # make sure the solution for shared params has norm <= self.eps
278 | if self.max_norm > 0:
279 | torch.nn.utils.clip_grad_norm_(shared_parameters, self.max_norm)
280 |
281 | return loss, extra_outputs
282 |
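# --- Illustration (editor's sketch; not part of weight_methods.py) ----------
# The cvxpy problem above searches for alpha satisfying the first-order
# condition (G G^T) alpha ~= 1 / alpha elementwise (see _stop_criteria). A toy
# residual check with random gradients:
#
#     G = torch.randn(3, 5)                  # one flattened gradient per task
#     GTG = G @ G.t()                        # (3, 3) Gram matrix
#     alpha = torch.ones(3)
#     residual = GTG @ alpha - 1.0 / alpha   # the solver drives this to ~0
#
# In get_weighted_loss, GTG is first divided by its Frobenius norm
# (normalization_factor) before being handed to the solver.
# ---------------------------------------------------------------------------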
283 | class LinearScalarization(WeightMethod):
284 | """Linear scalarization baseline L = sum_j w_j * l_j where l_j is the loss for task j and w_h"""
285 |
286 | def __init__(
287 | self,
288 | n_tasks: int,
289 | device: torch.device,
290 | task_weights: Union[List[float], torch.Tensor] = None,
291 | ):
292 | super().__init__(n_tasks, device=device)
293 | if task_weights is None:
294 | task_weights = torch.ones((n_tasks,))
295 | if not isinstance(task_weights, torch.Tensor):
296 | task_weights = torch.tensor(task_weights)
297 | assert len(task_weights) == n_tasks
298 | self.task_weights = task_weights.to(device)
299 |
300 | def get_weighted_loss(self, losses, **kwargs):
301 | loss = torch.sum(losses * self.task_weights)
302 | return loss, dict(weights=self.task_weights)
303 |
304 | class ScaleInvariantLinearScalarization(WeightMethod):
305 | """Linear scalarization baseline L = sum_j w_j * l_j where l_j is the loss for task j and w_h"""
306 |
307 | def __init__(
308 | self,
309 | n_tasks: int,
310 | device: torch.device,
311 | task_weights: Union[List[float], torch.Tensor] = None,
312 | ):
313 | super().__init__(n_tasks, device=device)
314 | if task_weights is None:
315 | task_weights = torch.ones((n_tasks,))
316 | if not isinstance(task_weights, torch.Tensor):
317 | task_weights = torch.tensor(task_weights)
318 | assert len(task_weights) == n_tasks
319 | self.task_weights = task_weights.to(device)
320 |
321 | def get_weighted_loss(self, losses, **kwargs):
322 | loss = torch.sum(torch.log(losses) * self.task_weights)
323 | return loss, dict(weights=self.task_weights)
324 |
325 | class MGDA(WeightMethod):
326 | """Based on the official implementation of: Multi-Task Learning as Multi-Objective Optimization
327 | Ozan Sener, Vladlen Koltun
328 | Neural Information Processing Systems (NeurIPS) 2018
329 | https://github.com/intel-isl/MultiObjectiveOptimization
330 |
331 | """
332 |
333 | def __init__(
334 | self, n_tasks, device: torch.device, params="shared", normalization="none"
335 | ):
336 | super().__init__(n_tasks, device=device)
337 | self.solver = MinNormSolver()
338 | assert params in ["shared", "last", "rep"]
339 | self.params = params
340 | assert normalization in ["norm", "loss", "loss+", "none"]
341 | self.normalization = normalization
342 |
343 | @staticmethod
344 | def _flattening(grad):
345 | return torch.cat(tuple(g.reshape(-1) for g in grad), dim=0)
354 |
355 | def get_weighted_loss(
356 | self,
357 | losses,
358 | shared_parameters=None,
359 | last_shared_parameters=None,
360 | representation=None,
361 | **kwargs,
362 | ):
363 | """
364 |
365 | Parameters
366 | ----------
367 | losses :
368 | shared_parameters :
369 | last_shared_parameters :
370 | representation :
371 | kwargs :
372 |
373 | Returns
374 | -------
375 |
376 | """
377 | # Our code
378 | grads = {}
379 | params = dict(
380 | rep=representation, shared=shared_parameters, last=last_shared_parameters
381 | )[self.params]
382 | for i, loss in enumerate(losses):
383 | g = list(
384 | torch.autograd.grad(
385 | loss,
386 | params,
387 | retain_graph=True,
388 | )
389 | )
390 | # Normalize all gradients, this is optional and not included in the paper.
391 |
392 | grads[i] = [torch.flatten(grad) for grad in g]
393 |
394 | gn = gradient_normalizers(grads, losses, self.normalization)
395 | for t in range(self.n_tasks):
396 | for gr_i in range(len(grads[t])):
397 | grads[t][gr_i] = grads[t][gr_i] / gn[t]
398 |
399 | sol, min_norm = self.solver.find_min_norm_element(
400 | [grads[t] for t in range(len(grads))]
401 | )
402 | sol = sol * self.n_tasks # make sure it sums to self.n_tasks
403 | weighted_loss = sum([losses[i] * sol[i] for i in range(len(sol))])
404 |
405 | return weighted_loss, dict(weights=torch.from_numpy(sol.astype(np.float32)))
406 |
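# --- Illustration (editor's sketch; not part of weight_methods.py) ----------
# For two tasks the min-norm problem MGDA solves has a closed form:
#   min_{a in [0,1]} || a*g1 + (1-a)*g2 ||^2,  a* = <g2 - g1, g2> / ||g1 - g2||^2
# (clipped to [0, 1]). For example:
#
#     g1 = torch.tensor([1.0, 0.0]); g2 = torch.tensor([0.0, 2.0])
#     a = ((g2 - g1) @ g2 / (g1 - g2).norm() ** 2).clamp(0.0, 1.0)  # 0.8
#     g = a * g1 + (1 - a) * g2                                     # [0.8, 0.4]
#
# g is the minimum-norm convex combination of the two task gradients.
# ---------------------------------------------------------------------------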
407 | class STL(WeightMethod):
408 | """Single task learning"""
409 |
410 | def __init__(self, n_tasks, device: torch.device, main_task):
411 | super().__init__(n_tasks, device=device)
412 | self.main_task = main_task
413 | self.weights = torch.zeros(n_tasks, device=device)
414 | self.weights[main_task] = 1.0
415 |
416 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
417 | assert len(losses) == self.n_tasks
418 | loss = losses[self.main_task]
419 |
420 | return loss, dict(weights=self.weights)
421 |
422 | class Uncertainty(WeightMethod):
423 | """Implementation of `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics`
424 | Source: https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynb
425 | """
426 |
427 | def __init__(self, n_tasks, device: torch.device):
428 | super().__init__(n_tasks, device=device)
429 | self.logsigma = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True)
430 |
431 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
432 | loss = sum(
433 | [
434 | 0.5 * (torch.exp(-logs) * loss + logs)
435 | for loss, logs in zip(losses, self.logsigma)
436 | ]
437 | )
438 |
439 | return loss, dict(
440 | weights=torch.exp(-self.logsigma)
441 | ) # NOTE: not exactly task weights
442 |
443 | def parameters(self) -> List[torch.Tensor]:
444 | return [self.logsigma]
445 |
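# --- Illustration (editor's sketch; not part of weight_methods.py) ----------
# Sanity check for the uncertainty weighting above: with all logsigma = 0 the
# weighted loss reduces to 0.5 * sum(losses). With made-up values:
#
#     losses = torch.tensor([1.0, 2.0, 3.0])
#     logsigma = torch.zeros(3)
#     loss = sum(0.5 * (torch.exp(-s) * l + s) for l, s in zip(losses, logsigma))
#     # loss == 3.0 == 0.5 * (1 + 2 + 3)
# ---------------------------------------------------------------------------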
446 | class PCGrad(WeightMethod):
447 | """Modification of: https://github.com/WeiChengTseng/Pytorch-PCGrad/blob/master/pcgrad.py
448 |
449 | @misc{Pytorch-PCGrad,
450 | author = {Wei-Cheng Tseng},
451 | title = {WeiChengTseng/Pytorch-PCGrad},
452 | url = {https://github.com/WeiChengTseng/Pytorch-PCGrad.git},
453 | year = {2020}
454 | }
455 |
456 | """
457 |
458 | def __init__(self, n_tasks: int, device: torch.device, reduction="sum"):
459 | super().__init__(n_tasks, device=device)
460 | assert reduction in ["mean", "sum"]
461 | self.reduction = reduction
462 |
463 | def get_weighted_loss(
464 | self,
465 | losses: torch.Tensor,
466 | shared_parameters: Union[
467 | List[torch.nn.parameter.Parameter], torch.Tensor
468 | ] = None,
469 | task_specific_parameters: Union[
470 | List[torch.nn.parameter.Parameter], torch.Tensor
471 | ] = None,
472 | **kwargs,
473 | ):
474 | raise NotImplementedError
475 |
476 | def _set_pc_grads(self, losses, shared_parameters, task_specific_parameters=None):
477 | # shared part
478 | shared_grads = []
479 | for l in losses:
480 | shared_grads.append(
481 | torch.autograd.grad(l, shared_parameters, retain_graph=True)
482 | )
483 |
484 | if isinstance(shared_parameters, torch.Tensor):
485 | shared_parameters = [shared_parameters]
486 | non_conflict_shared_grads = self._project_conflicting(shared_grads)
487 | for p, g in zip(shared_parameters, non_conflict_shared_grads):
488 | p.grad = g
489 |
490 | # task specific part
491 | if task_specific_parameters is not None:
492 | task_specific_grads = torch.autograd.grad(
493 | losses.sum(), task_specific_parameters
494 | )
495 | if isinstance(task_specific_parameters, torch.Tensor):
496 | task_specific_parameters = [task_specific_parameters]
497 | for p, g in zip(task_specific_parameters, task_specific_grads):
498 | p.grad = g
499 |
500 | def _project_conflicting(self, grads: List[Tuple[torch.Tensor]]):
501 | pc_grad = copy.deepcopy(grads)
502 | for g_i in pc_grad:
503 | random.shuffle(grads)
504 | for g_j in grads:
505 | g_i_g_j = sum(
506 | [
507 | torch.dot(torch.flatten(grad_i), torch.flatten(grad_j))
508 | for grad_i, grad_j in zip(g_i, g_j)
509 | ]
510 | )
511 | if g_i_g_j < 0:
512 | g_j_norm_square = (
513 | torch.norm(torch.cat([torch.flatten(g) for g in g_j])) ** 2
514 | )
515 | for grad_i, grad_j in zip(g_i, g_j):
516 | grad_i -= g_i_g_j * grad_j / g_j_norm_square
517 |
518 | merged_grad = [sum(g) for g in zip(*pc_grad)]
519 | if self.reduction == "mean":
520 | merged_grad = [g / self.n_tasks for g in merged_grad]
521 |
522 | return merged_grad
523 |
524 | def backward(
525 | self,
526 | losses: torch.Tensor,
527 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
528 | shared_parameters: Union[
529 | List[torch.nn.parameter.Parameter], torch.Tensor
530 | ] = None,
531 | task_specific_parameters: Union[
532 | List[torch.nn.parameter.Parameter], torch.Tensor
533 | ] = None,
534 | **kwargs,
535 | ):
536 | self._set_pc_grads(losses, shared_parameters, task_specific_parameters)
537 | return None, {} # NOTE: to align with all other weight methods
538 |
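# --- Illustration (editor's sketch; not part of weight_methods.py) ----------
# Two-gradient example of the projection in _project_conflicting above: when
# g1 and g2 conflict (negative dot product), g1 loses its component along g2.
#
#     g1 = torch.tensor([1.0, 0.0])
#     g2 = torch.tensor([-1.0, 1.0])
#     dot = g1 @ g2                              # -1.0 < 0 -> conflict
#     g1_pc = g1 - dot * g2 / g2.norm() ** 2     # -> [0.5, 0.5]
#     # g1_pc @ g2 == 0: the conflicting component is removed
# ---------------------------------------------------------------------------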
539 | class CAGrad(WeightMethod):
540 | def __init__(self, n_tasks, device: torch.device, c=0.4):
541 | super().__init__(n_tasks, device=device)
542 | self.c = c
543 |
544 | def get_weighted_loss(
545 | self,
546 | losses,
547 | shared_parameters,
548 | **kwargs,
549 | ):
550 | """
551 | Parameters
552 | ----------
553 | losses :
554 | shared_parameters : shared parameters
555 | kwargs :
556 | Returns
557 | -------
558 | """
559 | # NOTE: we allow only shared params for now. Need to see paper for other options.
560 | grad_dims = []
561 | for param in shared_parameters:
562 | grad_dims.append(param.data.numel())
563 | grads = torch.Tensor(sum(grad_dims), self.n_tasks).to(self.device)
564 |
565 | for i in range(self.n_tasks):
566 | if i < (self.n_tasks - 1):
567 | losses[i].backward(retain_graph=True)
568 | else:
569 | losses[i].backward()
570 | self.grad2vec(shared_parameters, grads, grad_dims, i)
571 | # multi_task_model.zero_grad_shared_modules()
572 | for p in shared_parameters:
573 | p.grad = None
574 |
575 | g = self.cagrad(grads, alpha=self.c, rescale=1)
576 | self.overwrite_grad(shared_parameters, g, grad_dims)
577 |
578 | def cagrad(self, grads, alpha=0.5, rescale=1):
579 | GG = grads.t().mm(grads).cpu() # [num_tasks, num_tasks]
580 | g0_norm = (GG.mean() + 1e-8).sqrt() # norm of the average gradient
581 |
582 | x_start = np.ones(self.n_tasks) / self.n_tasks
583 | bnds = tuple((0, 1) for x in x_start)
584 | cons = {"type": "eq", "fun": lambda x: 1 - sum(x)}
585 | A = GG.numpy()
586 | b = x_start.copy()
587 | c = (alpha * g0_norm + 1e-8).item()
588 |
589 | def objfn(x):
590 | return (
591 | x.reshape(1, self.n_tasks).dot(A).dot(b.reshape(self.n_tasks, 1))
592 | + c
593 | * np.sqrt(
594 | x.reshape(1, self.n_tasks).dot(A).dot(x.reshape(self.n_tasks, 1))
595 | + 1e-8
596 | )
597 | ).sum()
598 |
599 | res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
600 | w_cpu = res.x
601 | ww = torch.Tensor(w_cpu).to(grads.device)
602 | gw = (grads * ww.view(1, -1)).sum(1)
603 | gw_norm = gw.norm()
604 | lmbda = c / (gw_norm + 1e-8)
605 | g = grads.mean(1) + lmbda * gw
606 | if rescale == 0:
607 | return g
608 | elif rescale == 1:
609 | return g / (1 + alpha ** 2)
610 | else:
611 | return g / (1 + alpha)
612 |
613 | @staticmethod
614 | def grad2vec(shared_params, grads, grad_dims, task):
615 | # store the gradients
616 | grads[:, task].fill_(0.0)
617 | cnt = 0
618 | # for mm in m.shared_modules():
619 | # for p in mm.parameters():
620 |
621 | for param in shared_params:
622 | grad = param.grad
623 | if grad is not None:
624 | grad_cur = grad.data.detach().clone()
625 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
626 | en = sum(grad_dims[: cnt + 1])
627 | grads[beg:en, task].copy_(grad_cur.data.view(-1))
628 | cnt += 1
629 |
630 | def overwrite_grad(self, shared_parameters, newgrad, grad_dims):
631 | newgrad = newgrad * self.n_tasks # to match the sum loss
632 | cnt = 0
633 |
634 | # for mm in m.shared_modules():
635 | # for param in mm.parameters():
636 | for param in shared_parameters:
637 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
638 | en = sum(grad_dims[: cnt + 1])
639 | this_grad = newgrad[beg:en].contiguous().view(param.data.size())
640 | param.grad = this_grad.data.clone()
641 | cnt += 1
642 |
643 | def backward(
644 | self,
645 | losses: torch.Tensor,
646 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
647 | shared_parameters: Union[
648 | List[torch.nn.parameter.Parameter], torch.Tensor
649 | ] = None,
650 | task_specific_parameters: Union[
651 | List[torch.nn.parameter.Parameter], torch.Tensor
652 | ] = None,
653 | **kwargs,
654 | ):
655 | self.get_weighted_loss(losses, shared_parameters)
656 | return None, {} # NOTE: to align with all other weight methods
657 |
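# --- Illustration (editor's sketch; not part of weight_methods.py) ----------
# cagrad() above solves, over the probability simplex,
#     min_w  w^T (G^T G) w0 + c * sqrt(w^T (G^T G) w),
# where w0 is uniform (so the first term involves the average gradient g0) and
# c = alpha * ||g0||. The returned update is g = g0 + (c / ||G w||) * (G w):
# the average gradient nudged toward the worst-case-weighted direction, then
# divided by (1 + alpha^2) when rescale == 1.
# ---------------------------------------------------------------------------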
658 | class RLW(WeightMethod):
659 | """Random loss weighting: https://arxiv.org/pdf/2111.10603.pdf"""
660 |
661 | def __init__(self, n_tasks, device: torch.device):
662 | super().__init__(n_tasks, device=device)
663 |
664 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
665 | assert len(losses) == self.n_tasks
666 | weight = (F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device)
667 | loss = torch.sum(losses * weight)
668 |
669 | return loss, dict(weights=weight)
670 |
671 | class IMTLG(WeightMethod):
672 | """TOWARDS IMPARTIAL MULTI-TASK LEARNING: https://openreview.net/pdf?id=IMPnRXEWpvr"""
673 |
674 | def __init__(self, n_tasks, device: torch.device):
675 | super().__init__(n_tasks, device=device)
676 |
677 | def get_weighted_loss(
678 | self,
679 | losses,
680 | shared_parameters,
681 | **kwargs,
682 | ):
683 | grads = {}
684 | norm_grads = {}
685 |
686 | for i, loss in enumerate(losses):
687 | g = list(
688 | torch.autograd.grad(
689 | loss,
690 | shared_parameters,
691 | retain_graph=True,
692 | )
693 | )
694 | grad = torch.cat([torch.flatten(grad) for grad in g])
695 | norm_term = torch.norm(grad)
696 |
697 | grads[i] = grad
698 | norm_grads[i] = grad / norm_term
699 |
700 | G = torch.stack(tuple(v for v in grads.values()))
701 | D = G[0] - G[1:]
702 |
703 | U = torch.stack(tuple(v for v in norm_grads.values()))
704 | U = U[0] - U[1:]
705 | first_element = torch.matmul(G[0], U.t())
725 | try:
726 | second_element = torch.inverse(torch.matmul(D, U.t()))
727 | except Exception:
728 | # workaround for cases where matrix is singular
729 | second_element = torch.inverse(
730 | torch.eye(self.n_tasks - 1, device=self.device) * 1e-8
731 | + torch.matmul(D, U.t())
732 | )
733 |
734 | alpha_ = torch.matmul(first_element, second_element)
735 | alpha = torch.cat(
736 | (torch.tensor(1 - alpha_.sum(), device=self.device).unsqueeze(-1), alpha_)
737 | )
738 |
739 | loss = torch.sum(losses * alpha)
740 |
741 | return loss, dict(weights=alpha)
742 |
743 | class DynamicWeightAverage(WeightMethod):
744 | """Dynamic Weight Average from `End-to-End Multi-Task Learning with Attention`.
745 | Modification of: https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_segnet_split.py#L242
746 | """
747 |
748 | def __init__(
749 | self, n_tasks, device: torch.device, iteration_window: int = 25, temp=2.0
750 | ):
751 | """
752 |
753 | Parameters
754 | ----------
755 | n_tasks :
756 | iteration_window : 'iteration' loss is averaged over the last 'iteration_window' losses
757 | temp :
758 | """
759 | super().__init__(n_tasks, device=device)
760 | self.iteration_window = iteration_window
761 | self.temp = temp
762 | self.running_iterations = 0
763 | self.costs = np.ones((iteration_window * 2, n_tasks), dtype=np.float32)
764 | self.weights = np.ones(n_tasks, dtype=np.float32)
765 |
766 | def get_weighted_loss(self, losses, **kwargs):
767 |
768 | cost = losses.detach().cpu().numpy()
769 |
770 | # update costs - fifo
771 | self.costs[:-1, :] = self.costs[1:, :]
772 | self.costs[-1, :] = cost
773 |
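    |         # DWA weighting: r_i is the ratio of task i's mean loss over the most
    |         # recent `iteration_window` steps to its mean over the preceding window;
    |         # weights are w_i = n_tasks * exp(r_i / temp) / sum_j exp(r_j / temp),
    |         # so tasks whose losses decay more slowly receive larger weights.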
774 | if self.running_iterations > self.iteration_window:
775 | ws = self.costs[self.iteration_window :, :].mean(0) / self.costs[
776 | : self.iteration_window, :
777 | ].mean(0)
778 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (
779 | np.exp(ws / self.temp)
780 | ).sum()
781 |
782 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(
783 | losses.device
784 | )
785 | loss = (task_weights * losses).mean()
786 |
787 | self.running_iterations += 1
788 |
789 | return loss, dict(weights=task_weights)
790 |
791 | class FAMO(WeightMethod):
792 |     """FAMO: Fast Adaptive Multitask Optimization (https://arxiv.org/abs/2306.03792)."""
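    |     # FAMO adapts the task weighting online so that all task losses decrease
    |     # at an approximately equal rate in log space; it only needs the losses
    |     # from the previous step (no per-task backward passes through the shared
    |     # backbone), so its overhead is O(1) in the number of tasks.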
793 |
794 | def __init__(
795 | self,
796 | n_tasks: int,
797 | device: torch.device,
798 | gamma: float = 1e-5,
799 | w_lr: float = 0.025,
800 | task_weights: Union[List[float], torch.Tensor] = None,
801 | max_norm: float = 1.0,
802 | ):
803 | super().__init__(n_tasks, device=device)
804 | self.min_losses = torch.zeros(n_tasks).to(device)
805 | self.w = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True)
806 | self.w_opt = torch.optim.Adam([self.w], lr=w_lr, weight_decay=gamma)
807 | self.max_norm = max_norm
808 |
809 | def set_min_losses(self, losses):
810 | self.min_losses = losses
811 |
812 | def get_weighted_loss(self, losses, **kwargs):
813 | self.prev_loss = losses
814 | z = F.softmax(self.w, -1)
815 | D = losses - self.min_losses + 1e-8
816 | c = (z / D).sum().detach()
817 | loss = (D.log() * z / c).sum()
818 |         return loss, dict(weights=z)
820 |
821 | def update(self, curr_loss):
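    |         # delta_i = log(L_i^prev - L_i^min) - log(L_i^curr - L_i^min) is task i's
    |         # log-loss improvement over the last step. The vector-Jacobian product
    |         # below (autograd with grad_outputs) computes d = J_softmax(w)^T delta,
    |         # which is then used as the gradient for the Adam step on the logits w.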
822 | delta = (self.prev_loss - self.min_losses + 1e-8).log() - \
823 | (curr_loss - self.min_losses + 1e-8).log()
824 | with torch.enable_grad():
825 | d = torch.autograd.grad(F.softmax(self.w, -1),
826 | self.w,
827 | grad_outputs=delta.detach())[0]
828 | self.w_opt.zero_grad()
829 | self.w.grad = d
830 | self.w_opt.step()
831 |
832 | class GradDrop(WeightMethod):
833 | def __init__(self, n_tasks, device: torch.device, max_norm=1.0):
834 | super().__init__(n_tasks, device=device)
835 | self.max_norm = max_norm
836 |
837 | def get_weighted_loss(
838 | self,
839 | losses,
840 | shared_parameters,
841 | **kwargs,
842 | ):
843 |         """
844 |         Parameters
845 |         ----------
846 |         losses : tensor of per-task losses
847 |         shared_parameters : parameters shared across tasks; GradDrop masks their gradients
848 |         """
852 | # NOTE: we allow only shared params for now. Need to see paper for other options.
853 | grad_dims = []
854 | for param in shared_parameters:
855 | grad_dims.append(param.data.numel())
856 | grads = torch.Tensor(sum(grad_dims), self.n_tasks).to(self.device)
857 |
858 |         for i in range(self.n_tasks):
859 |             if i < self.n_tasks - 1:
860 |                 losses[i].backward(retain_graph=True)
861 |             else:
862 |                 losses[i].backward()
863 | self.grad2vec(shared_parameters, grads, grad_dims, i)
864 |             # reset shared gradients before the next task's backward pass
865 |             for p in shared_parameters:
866 |                 p.grad = None
867 |
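    |         # Gradient sign dropout: P_k is the fraction of positive (signed)
    |         # gradient mass at coordinate k across tasks. Each coordinate keeps
    |         # either its positive or its negative per-task components, chosen at
    |         # random with probability P_k, so conflicting gradient directions are
    |         # stochastically dropped before averaging.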
868 |         P = 0.5 * (1.0 + grads.sum(1) / (grads.abs().sum(1) + 1e-8))
869 |         U = torch.rand_like(grads[:, 0])
870 |         M = P.gt(U).view(-1, 1) * grads.gt(0) + P.lt(U).view(-1, 1) * grads.lt(0)
871 | g = (grads * M.float()).mean(1)
872 | self.overwrite_grad(shared_parameters, g, grad_dims)
873 |
874 | @staticmethod
875 | def grad2vec(shared_params, grads, grad_dims, task):
876 |         # flatten this task's parameter gradients into column `task` of `grads`
877 | grads[:, task].fill_(0.0)
878 | cnt = 0
881 |
882 | for param in shared_params:
883 | grad = param.grad
884 | if grad is not None:
885 | grad_cur = grad.data.detach().clone()
886 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
887 | en = sum(grad_dims[: cnt + 1])
888 | grads[beg:en, task].copy_(grad_cur.data.view(-1))
889 | cnt += 1
890 |
891 | def overwrite_grad(self, shared_parameters, newgrad, grad_dims):
892 | newgrad = newgrad * self.n_tasks # to match the sum loss
893 | cnt = 0
894 |
897 | for param in shared_parameters:
898 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
899 | en = sum(grad_dims[: cnt + 1])
900 | this_grad = newgrad[beg:en].contiguous().view(param.data.size())
901 | param.grad = this_grad.data.clone()
902 | cnt += 1
903 |
904 | def backward(
905 | self,
906 | losses: torch.Tensor,
907 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
908 | shared_parameters: Union[
909 | List[torch.nn.parameter.Parameter], torch.Tensor
910 | ] = None,
911 | task_specific_parameters: Union[
912 | List[torch.nn.parameter.Parameter], torch.Tensor
913 | ] = None,
914 | **kwargs,
915 | ):
917 | self.get_weighted_loss(losses, shared_parameters)
918 | if self.max_norm > 0:
919 | torch.nn.utils.clip_grad_norm_(shared_parameters, self.max_norm)
920 | return None, None # NOTE: to align with all other weight methods
921 |
922 | class WeightMethods:
923 | def __init__(self, method: str, n_tasks: int, device: torch.device, **kwargs):
924 |         """
925 |         :param method: name of the weighting method; must be a key of METHODS
926 |         """
927 | assert method in list(METHODS.keys()), f"unknown method {method}."
928 |
929 | self.method = METHODS[method](n_tasks=n_tasks, device=device, **kwargs)
930 |
931 | def get_weighted_loss(self, losses, **kwargs):
932 | return self.method.get_weighted_loss(losses, **kwargs)
933 |
934 | def backward(
935 | self, losses, **kwargs
936 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
937 | return self.method.backward(losses, **kwargs)
938 |
939 |     def __call__(self, losses, **kwargs):
940 | return self.backward(losses, **kwargs)
941 |
942 | def parameters(self):
943 | return self.method.parameters()
944 |
945 | METHODS = dict(
946 | stl=STL,
947 | ls=LinearScalarization,
948 | uw=Uncertainty,
949 | pcgrad=PCGrad,
950 | mgda=MGDA,
951 | cagrad=CAGrad,
952 | nashmtl=NashMTL,
953 | scaleinvls=ScaleInvariantLinearScalarization,
954 | rlw=RLW,
955 | imtl=IMTLG,
956 | dwa=DynamicWeightAverage,
957 | #=========
958 | graddrop=GradDrop,
959 | famo=FAMO,
960 | go4align=GO4ALIGN,
961 | group=GROUP,
962 | group_random=GROUP_RANDOM,
963 | #=========
964 | group_sklearn_spectral_clustering_cluster_qr=GROUP_sklearn_spectral_clustering_cluster_qr,
965 | group_sklearn_spectral_clustering_discretize=GROUP_sklearn_spectral_clustering_discretize,
966 | group_sklearn_spectral_clustering_kmeans=GROUP_sklearn_spectral_clustering_kmeans,
967 | group_sdp_clustering=GROUP_sdp_clustering,
969 | )
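    | # Minimal usage sketch (hypothetical variable names; the actual training
    | # loop lives in experiments/nyuv2/trainer.py):
    | #
    | #     weight_method = WeightMethods("famo", n_tasks=3, device=device)
    | #     losses = torch.stack([loss_fns[t](preds[t], ys[t]) for t in range(3)])
    | #     weight_method.backward(
    | #         losses,
    | #         shared_parameters=list(model.shared_parameters()),
    | #         task_specific_parameters=list(model.task_specific_parameters()),
    | #     )
    | #     optimizer.step()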
--------------------------------------------------------------------------------