├── experiments ├── __init__.py ├── toy │ ├── __init__.py │ ├── README.md │ ├── utils.py │ ├── problem.py │ └── trainer.py ├── nyuv2 │ ├── __init__.py │ ├── dataset │ │ └── .gitinclude │ ├── README.md │ ├── utils.py │ ├── data.py │ ├── trainer.py │ └── models.py ├── quantum_chemistry │ ├── __init__.py │ ├── README.md │ ├── models.py │ ├── utils.py │ └── trainer.py └── utils.py ├── .isort.cfg ├── misc └── toy_pareto_2d.png ├── methods ├── __init__.py ├── min_norm_solvers.py └── weight_methods.py ├── requirements.txt ├── LICENSE ├── setup.py ├── .gitignore └── README.md /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/toy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/nyuv2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/nyuv2/dataset/.gitinclude: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/quantum_chemistry/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | multi_line_output = 3 4 | -------------------------------------------------------------------------------- /misc/toy_pareto_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AvivNavon/nash-mtl/HEAD/misc/toy_pareto_2d.png -------------------------------------------------------------------------------- /experiments/toy/README.md: -------------------------------------------------------------------------------- 1 | # Illustrative Example 2 | 3 | Modification of the code in [CAGrad](https://github.com/Cranial-XIX/CAGrad). -------------------------------------------------------------------------------- /experiments/quantum_chemistry/README.md: -------------------------------------------------------------------------------- 1 | # QM9 Experiment 2 | 3 | Modification of the example code in [PyG](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_nn_conv.py). -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from methods.weight_methods import ( 2 | METHODS, 3 | MGDA, 4 | STL, 5 | LinearScalarization, 6 | NashMTL, 7 | PCGrad, 8 | Uncertainty, 9 | ) 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.2.1 2 | numpy>=1.18.2 3 | torch>=1.4.0 4 | torchvision>=0.8.0 5 | cvxpy==1.5.4 6 | tqdm>=4.45.0 7 | pandas 8 | scikit-learn 9 | seaborn 10 | wandb 11 | plotly 12 | -------------------------------------------------------------------------------- /experiments/nyuv2/README.md: -------------------------------------------------------------------------------- 1 | # NYUv2 Experiment 2 | 3 | Modification of the code in [CAGrad](https://github.com/Cranial-XIX/CAGrad) and [MTAN](https://github.com/lorenmt/mtan). 4 | 5 | ## Dataset 6 | 7 | The dataset is available at [this link](https://www.dropbox.com/sh/86nssgwm6hm3vkb/AACrnUQ4GxpdrBbLjb6n-mWNa?dl=0). Put the downloaded files in `./dataset` so that the folder structure is `.dataset/train` and `./dataset/val`. 8 | 9 | ## Evaluation 10 | 11 | To align with previous work on MTL [Liu et al. (2019)](https://arxiv.org/abs/1803.10704); [Yu et al. (2020)](https://arxiv.org/abs/2001.06782); [Liu et al. (2021)](https://arxiv.org/pdf/2110.14048.pdf) we report the test performance averaged 12 | over the last 10 epochs. Note that this averaging is not handled in the code and need to be applied by the user. 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Aviv Navon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from io import open 2 | from os import path 3 | 4 | from setuptools import find_packages, setup 5 | 6 | here = path.abspath(path.dirname(__file__)) 7 | 8 | # get the long description from the README.md file 9 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 10 | long_description = f.read() 11 | 12 | 13 | # get reqs 14 | def requirements(): 15 | list_requirements = [] 16 | with open("requirements.txt") as f: 17 | for line in f: 18 | list_requirements.append(line.rstrip()) 19 | return list_requirements 20 | 21 | 22 | setup( 23 | name="nashmtl", 24 | version="1.0.0", # Required 25 | description="Nash-MTL", # Optional 26 | long_description="", # Optional 27 | long_description_content_type="text/markdown", # Optional (see note above) 28 | url="", # Optional 29 | author="", # Optional 30 | author_email="", # Optional 31 | packages=find_packages(exclude=["contrib", "docs", "tests"]), 32 | # packages=find_packages(exclude=['contrib', 'docs', 'tests']), # Required 33 | python_requires=">=3.6", 34 | install_requires=requirements(), # Optional 35 | ) 36 | -------------------------------------------------------------------------------- /experiments/toy/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import seaborn as sns 4 | import torch 5 | from matplotlib import pyplot as plt 6 | 7 | from experiments.toy.problem import Toy 8 | 9 | 10 | # plotting utils 11 | def plot_2d_pareto(trajectories: dict, scale): 12 | """Adaptation of code from: https://github.com/Cranial-XIX/CAGrad""" 13 | fig, ax = plt.subplots(figsize=(6, 5)) 14 | 15 | F = Toy(scale=scale) 16 | 17 | losses = [] 18 | for res in trajectories.values(): 19 | losses.append(F.batch_forward(torch.from_numpy(res["traj"]))) 20 | 21 | yy = -8.3552 22 | x = np.linspace(-7, 7, 1000) 23 | 24 | inpt = np.stack((x, [yy] * len(x))).T 25 | Xs = torch.from_numpy(inpt).double() 26 | 27 | Ys = F.batch_forward(Xs) 28 | ax.plot( 29 | Ys.numpy()[:, 0], 30 | Ys.numpy()[:, 1], 31 | "-", 32 | linewidth=8, 33 | color="#72727A", 34 | label="Pareto Front", 35 | ) # Pareto front 36 | 37 | for i, tt in enumerate(losses): 38 | ax.scatter( 39 | tt[0, 0], 40 | tt[0, 1], 41 | color="k", 42 | s=150, 43 | zorder=10, 44 | label="Initial Point" if i == 0 else None, 45 | ) 46 | colors = matplotlib.cm.magma_r(np.linspace(0.1, 0.6, tt.shape[0])) 47 | ax.scatter(tt[:, 0], tt[:, 1], color=colors, s=5, zorder=9) 48 | 49 | sns.despine() 50 | ax.set_xlabel(r"$\ell_1$", size=30) 51 | ax.set_ylabel(r"$\ell_2$", size=30) 52 | ax.xaxis.set_label_coords(1.015, -0.03) 53 | ax.yaxis.set_label_coords(-0.01, 1.01) 54 | 55 | for tick in ax.xaxis.get_major_ticks(): 56 | tick.label.set_fontsize(20) 57 | for tick in ax.yaxis.get_major_ticks(): 58 | tick.label.set_fontsize(20) 59 | 60 | plt.tight_layout() 61 | 62 | legend = ax.legend( 63 | loc=2, bbox_to_anchor=(-0.15, 1.3), frameon=False, fontsize=20, ncol=2 64 | ) 65 | 66 | return ax, fig, legend 67 | -------------------------------------------------------------------------------- /experiments/toy/problem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | LOWER = 0.000005 5 | 6 | 7 | class Toy(nn.Module): 8 | def __init__(self, scale=1.0, scale_both_losses=1.0): 9 | super(Toy, self).__init__() 10 | self.centers = torch.Tensor([[-3.0, 0], [3.0, 0]]) 11 | self.scale = scale 12 | self.scale_both_losses = scale_both_losses 13 | 14 | def forward(self, x, compute_grad=False): 15 | x1 = x[0] 16 | x2 = x[1] 17 | 18 | f1 = torch.clamp((0.5 * (-x1 - 7) - torch.tanh(-x2)).abs(), LOWER).log() + 6 19 | f2 = torch.clamp((0.5 * (-x1 + 3) + torch.tanh(-x2) + 2).abs(), LOWER).log() + 6 20 | c1 = torch.clamp(torch.tanh(x2 * 0.5), 0) 21 | 22 | f1_sq = ((-x1 + 7).pow(2) + 0.1 * (-x2 - 8).pow(2)) / 10 - 20 23 | f2_sq = ((-x1 - 7).pow(2) + 0.1 * (-x2 - 8).pow(2)) / 10 - 20 24 | c2 = torch.clamp(torch.tanh(-x2 * 0.5), 0) 25 | 26 | f1 = f1 * c1 + f1_sq * c2 27 | f1 *= self.scale 28 | f2 = f2 * c1 + f2_sq * c2 29 | 30 | f = torch.stack([f1, f2]) * self.scale_both_losses 31 | if compute_grad: 32 | g11 = torch.autograd.grad(f1, x1, retain_graph=True)[0].item() 33 | g12 = torch.autograd.grad(f1, x2, retain_graph=True)[0].item() 34 | g21 = torch.autograd.grad(f2, x1, retain_graph=True)[0].item() 35 | g22 = torch.autograd.grad(f2, x2, retain_graph=True)[0].item() 36 | g = torch.Tensor([[g11, g21], [g12, g22]]) 37 | return f, g 38 | else: 39 | return f 40 | 41 | def batch_forward(self, x): 42 | x1 = x[:, 0] 43 | x2 = x[:, 1] 44 | 45 | f1 = torch.clamp((0.5 * (-x1 - 7) - torch.tanh(-x2)).abs(), LOWER).log() + 6 46 | f2 = torch.clamp((0.5 * (-x1 + 3) + torch.tanh(-x2) + 2).abs(), LOWER).log() + 6 47 | c1 = torch.clamp(torch.tanh(x2 * 0.5), 0) 48 | 49 | f1_sq = ((-x1 + 7).pow(2) + 0.1 * (-x2 - 8).pow(2)) / 10 - 20 50 | f2_sq = ((-x1 - 7).pow(2) + 0.1 * (-x2 - 8).pow(2)) / 10 - 20 51 | c2 = torch.clamp(torch.tanh(-x2 * 0.5), 0) 52 | 53 | f1 = f1 * c1 + f1_sq * c2 54 | f1 *= self.scale 55 | f2 = f2 * c1 + f2_sq * c2 56 | 57 | f = torch.cat([f1.view(-1, 1), f2.view(-1, 1)], -1) * self.scale_both_losses 58 | return f 59 | -------------------------------------------------------------------------------- /experiments/quantum_chemistry/models.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import Iterator 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn import GRU, Linear, ReLU, Sequential 7 | from torch_geometric.nn import DimeNet, NNConv, Set2Set, radius_graph 8 | 9 | 10 | class Net(torch.nn.Module): 11 | def __init__(self, n_tasks, num_features=11, dim=64): 12 | super().__init__() 13 | self.n_tasks = n_tasks 14 | self.dim = dim 15 | self.lin0 = torch.nn.Linear(num_features, dim) 16 | 17 | nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim)) 18 | self.conv = NNConv(dim, dim, nn, aggr="mean") 19 | self.gru = GRU(dim, dim) 20 | 21 | self.set2set = Set2Set(dim, processing_steps=3) 22 | self.lin1 = torch.nn.Linear(2 * dim, dim) 23 | 24 | self._init_task_heads() 25 | 26 | def _init_task_heads(self): 27 | for i in range(self.n_tasks): 28 | setattr(self, f"head_{i}", torch.nn.Linear(self.dim, 1)) 29 | self.task_specific = torch.nn.ModuleList( 30 | [getattr(self, f"head_{i}") for i in range(self.n_tasks)] 31 | ) 32 | 33 | def forward(self, data, return_representation=False): 34 | out = F.relu(self.lin0(data.x)) 35 | h = out.unsqueeze(0) 36 | 37 | for i in range(3): 38 | m = F.relu(self.conv(out, data.edge_index, data.edge_attr)) 39 | out, h = self.gru(m.unsqueeze(0), h) 40 | out = out.squeeze(0) 41 | 42 | out = self.set2set(out, data.batch) 43 | features = F.relu(self.lin1(out)) 44 | logits = torch.cat( 45 | [getattr(self, f"head_{i}")(features) for i in range(self.n_tasks)], dim=1 46 | ) 47 | if return_representation: 48 | return logits, features 49 | return logits 50 | 51 | def shared_parameters(self) -> Iterator[torch.nn.parameter.Parameter]: 52 | return chain( 53 | self.lin0.parameters(), 54 | self.conv.parameters(), 55 | self.gru.parameters(), 56 | self.set2set.parameters(), 57 | self.lin1.parameters(), 58 | ) 59 | 60 | def task_specific_parameters(self) -> Iterator[torch.nn.parameter.Parameter]: 61 | return self.task_specific.parameters() 62 | 63 | def last_shared_parameters(self) -> Iterator[torch.nn.parameter.Parameter]: 64 | return self.lin1.parameters() 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .DS_Store 131 | .idea/ 132 | wandb/ 133 | datasets/ 134 | outputs/ 135 | data/ -------------------------------------------------------------------------------- /experiments/quantum_chemistry/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch_geometric.utils import remove_self_loops 4 | 5 | 6 | class MyTransform(object): 7 | def __init__(self, target: list = None): 8 | if target is None: 9 | target = torch.tensor([0, 1, 2, 3, 5, 6, 12, 13, 14, 15, 11]) # removing 4 10 | else: 11 | target = torch.tensor(target) 12 | self.target = target 13 | 14 | def __call__(self, data): 15 | # Specify target. 16 | data.y = data.y[:, self.target] 17 | return data 18 | 19 | 20 | class Complete(object): 21 | def __call__(self, data): 22 | device = data.edge_index.device 23 | 24 | row = torch.arange(data.num_nodes, dtype=torch.long, device=device) 25 | col = torch.arange(data.num_nodes, dtype=torch.long, device=device) 26 | 27 | row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1) 28 | col = col.repeat(data.num_nodes) 29 | edge_index = torch.stack([row, col], dim=0) 30 | 31 | edge_attr = None 32 | if data.edge_attr is not None: 33 | idx = data.edge_index[0] * data.num_nodes + data.edge_index[1] 34 | size = list(data.edge_attr.size()) 35 | size[0] = data.num_nodes * data.num_nodes 36 | edge_attr = data.edge_attr.new_zeros(size) 37 | edge_attr[idx] = data.edge_attr 38 | 39 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) 40 | data.edge_attr = edge_attr 41 | data.edge_index = edge_index 42 | 43 | return data 44 | 45 | 46 | qm9_target_dict = { 47 | 0: "mu", 48 | 1: "alpha", 49 | 2: "homo", 50 | 3: "lumo", 51 | 5: "r2", 52 | 6: "zpve", 53 | 7: "U0", 54 | 8: "U", 55 | 9: "H", 56 | 10: "G", 57 | 11: "Cv", 58 | } 59 | 60 | # for \Delta_m calculations 61 | # ------------------------- 62 | # DimeNet uses the atomization energy for targets U0, U, H, and G. 63 | target_idx = [0, 1, 2, 3, 5, 6, 12, 13, 14, 15, 11] 64 | 65 | # Report meV instead of eV. 66 | multiply_indx = [2, 3, 5, 6, 7, 8, 9] 67 | 68 | n_tasks = len(target_idx) 69 | 70 | # stl results 71 | BASE = np.array( 72 | [ 73 | 0.0671, 74 | 0.1814, 75 | 60.576, 76 | 53.915, 77 | 0.5027, 78 | 4.539, 79 | 58.838, 80 | 64.244, 81 | 63.852, 82 | 66.223, 83 | 0.07212, 84 | ] 85 | ) 86 | 87 | SIGN = np.array([0] * n_tasks) 88 | KK = np.ones(n_tasks) * -1 89 | 90 | 91 | def delta_fn(a): 92 | return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0 # *100 for percentage 93 | -------------------------------------------------------------------------------- /experiments/nyuv2/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ConfMatrix(object): 6 | def __init__(self, num_classes): 7 | self.num_classes = num_classes 8 | self.mat = None 9 | 10 | def update(self, pred, target): 11 | n = self.num_classes 12 | if self.mat is None: 13 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device) 14 | with torch.no_grad(): 15 | k = (target >= 0) & (target < n) 16 | inds = n * target[k].to(torch.int64) + pred[k] 17 | self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n) 18 | 19 | def get_metrics(self): 20 | h = self.mat.float() 21 | acc = torch.diag(h).sum() / h.sum() 22 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 23 | return torch.mean(iu).cpu().numpy(), acc.cpu().numpy() 24 | 25 | 26 | def depth_error(x_pred, x_output): 27 | device = x_pred.device 28 | binary_mask = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).to(device) 29 | x_pred_true = x_pred.masked_select(binary_mask) 30 | x_output_true = x_output.masked_select(binary_mask) 31 | abs_err = torch.abs(x_pred_true - x_output_true) 32 | rel_err = torch.abs(x_pred_true - x_output_true) / x_output_true 33 | return ( 34 | torch.sum(abs_err) / torch.nonzero(binary_mask, as_tuple=False).size(0) 35 | ).item(), ( 36 | torch.sum(rel_err) / torch.nonzero(binary_mask, as_tuple=False).size(0) 37 | ).item() 38 | 39 | 40 | def normal_error(x_pred, x_output): 41 | binary_mask = torch.sum(x_output, dim=1) != 0 42 | error = ( 43 | torch.acos( 44 | torch.clamp( 45 | torch.sum(x_pred * x_output, 1).masked_select(binary_mask), -1, 1 46 | ) 47 | ) 48 | .detach() 49 | .cpu() 50 | .numpy() 51 | ) 52 | error = np.degrees(error) 53 | return ( 54 | np.mean(error), 55 | np.median(error), 56 | np.mean(error < 11.25), 57 | np.mean(error < 22.5), 58 | np.mean(error < 30), 59 | ) 60 | 61 | 62 | # for calculating \Delta_m 63 | delta_stats = [ 64 | "mean iou", 65 | "pix acc", 66 | "abs err", 67 | "rel err", 68 | "mean", 69 | "median", 70 | "<11.25", 71 | "<22.5", 72 | "<30", 73 | ] 74 | BASE = np.array( 75 | [0.3830, 0.6376, 0.6754, 0.2780, 25.01, 19.21, 0.3014, 0.5720, 0.6915] 76 | ) # base results from CAGrad 77 | SIGN = np.array([1, 1, 0, 0, 0, 0, 1, 1, 1]) 78 | KK = np.ones(9) * -1 79 | 80 | 81 | def delta_fn(a): 82 | return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0 # * 100 for percentage 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Nash-MTL 2 | 3 | Official implementation of _"Multi-Task Learning as a Bargaining Game"_. 4 | 5 |

6 | 7 |

8 | 9 | ## Setup environment 10 | 11 | ```bash 12 | conda create -n nashmtl python=3.9.7 13 | conda activate nashmtl 14 | conda install pytorch==1.9.0 torchvision==0.10.0 cudatoolkit=10.2 -c pytorch 15 | conda install pyg -c pyg -c conda-forge 16 | ``` 17 | 18 | Install the repo: 19 | 20 | ```bash 21 | git clone https://github.com/AvivNavon/nash-mtl.git 22 | cd nash-mtl 23 | pip install -e . 24 | ``` 25 | 26 | ## Run experiment 27 | 28 | To run experiments: 29 | 30 | ```bash 31 | cd experiment/ 32 | python trainer.py --method=nashmtl 33 | ``` 34 | Follow instruction on the experiment README file for more information regarding, e.g., datasets. 35 | 36 | Here `` is one of `[toy, quantum_chemistry, nyuv2]`. You can also replace `nashmtl` with on of the following MTL methods. 37 | 38 | We also support experiment tracking with **[Weights & Biases](https://wandb.ai/site)** with two additional parameters: 39 | 40 | ```bash 41 | python trainer.py --method=nashmtl --wandb_project= --wandb_entity= 42 | ``` 43 | 44 | ## MTL methods 45 | 46 | We support the following MTL methods with a unified API. To run experiment with MTL method `X` simply run: 47 | ```bash 48 | python trainer.py --method=X 49 | ``` 50 | 51 | | Method (code name) | Paper (notes) | 52 | | :---: | :---: | 53 | | Nash-MTL (`nashmtl`) | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017v1.pdf) | 54 | | CAGrad (`cagrad`) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048.pdf) | 55 | | PCGrad (`pcgrad`) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/abs/2001.06782) | 56 | | IMTL-G (`imtl`) | [Towards Impartial Multi-task Learning](https://openreview.net/forum?id=IMPnRXEWpvr) | 57 | | MGDA (`mgda`) | [Multi-Task Learning as Multi-Objective Optimization](https://arxiv.org/abs/1810.04650) | 58 | | DWA (`dwa`) | [End-to-End Multi-Task Learning with Attention](https://arxiv.org/abs/1803.10704) | 59 | | Uncertainty weighting (`uw`) | [Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics](https://arxiv.org/pdf/1705.07115v3.pdf) | 60 | | Linear scalarization (`ls`) | - (equal weighting) | 61 | | Scale-invariant baseline (`scaleinvls`) | - (see Nash-MTL paper for details) | 62 | | Random Loss Weighting (`rlw`) | [A Closer Look at Loss Weighting in Multi-Task Learning](https://arxiv.org/pdf/2111.10603.pdf) | 63 | 64 | 65 | 66 | 67 | ## Citation 68 | 69 | If you find `Nash-MTL` to be useful in your own research, please consider citing the following paper: 70 | 71 | ```bib 72 | @article{navon2022multi, 73 | title={Multi-Task Learning as a Bargaining Game}, 74 | author={Navon, Aviv and Shamsian, Aviv and Achituve, Idan and Maron, Haggai and Kawaguchi, Kenji and Chechik, Gal and Fetaya, Ethan}, 75 | journal={arXiv preprint arXiv:2202.01017}, 76 | year={2022} 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /experiments/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import random 4 | from collections import defaultdict 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from methods import METHODS 11 | 12 | 13 | def str_to_list(string): 14 | return [float(s) for s in string.split(",")] 15 | 16 | 17 | def str_or_float(value): 18 | try: 19 | return float(value) 20 | except: 21 | return value 22 | 23 | 24 | def str2bool(v): 25 | if isinstance(v, bool): 26 | return v 27 | if v.lower() in ("yes", "true", "t", "y", "1"): 28 | return True 29 | elif v.lower() in ("no", "false", "f", "n", "0"): 30 | return False 31 | else: 32 | raise argparse.ArgumentTypeError("Boolean value expected.") 33 | 34 | 35 | common_parser = argparse.ArgumentParser(add_help=False) 36 | common_parser.add_argument("--data-path", type=Path, help="path to data") 37 | common_parser.add_argument("--n-epochs", type=int, default=300) 38 | common_parser.add_argument("--batch-size", type=int, default=120, help="batch size") 39 | common_parser.add_argument( 40 | "--method", type=str, choices=list(METHODS.keys()), help="MTL weight method" 41 | ) 42 | common_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") 43 | common_parser.add_argument( 44 | "--method-params-lr", 45 | type=float, 46 | default=0.025, 47 | help="lr for weight method params. If None, set to args.lr. For uncertainty weighting", 48 | ) 49 | common_parser.add_argument("--gpu", type=int, default=0, help="gpu device ID") 50 | common_parser.add_argument("--seed", type=int, default=42, help="seed value") 51 | # NashMTL 52 | common_parser.add_argument( 53 | "--nashmtl-optim-niter", type=int, default=20, help="number of CCCP iterations" 54 | ) 55 | common_parser.add_argument( 56 | "--update-weights-every", 57 | type=int, 58 | default=1, 59 | help="update task weights every x iterations.", 60 | ) 61 | # stl 62 | common_parser.add_argument( 63 | "--main-task", 64 | type=int, 65 | default=0, 66 | help="main task for stl. Ignored if method != stl", 67 | ) 68 | # cagrad 69 | common_parser.add_argument("--c", type=float, default=0.4, help="c for CAGrad alg.") 70 | # dwa 71 | # dwa 72 | common_parser.add_argument( 73 | "--dwa-temp", 74 | type=float, 75 | default=2.0, 76 | help="Temperature hyper-parameter for DWA. Default to 2 like in the original paper.", 77 | ) 78 | 79 | 80 | def count_parameters(model): 81 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 82 | 83 | 84 | def set_logger(): 85 | logging.basicConfig( 86 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 87 | level=logging.INFO, 88 | ) 89 | 90 | 91 | def set_seed(seed): 92 | """for reproducibility 93 | :param seed: 94 | :return: 95 | """ 96 | np.random.seed(seed) 97 | random.seed(seed) 98 | 99 | torch.manual_seed(seed) 100 | if torch.cuda.is_available(): 101 | torch.cuda.manual_seed(seed) 102 | torch.cuda.manual_seed_all(seed) 103 | 104 | torch.backends.cudnn.enabled = True 105 | torch.backends.cudnn.benchmark = False 106 | torch.backends.cudnn.deterministic = True 107 | 108 | 109 | def get_device(no_cuda=False, gpus="0"): 110 | return torch.device( 111 | f"cuda:{gpus}" if torch.cuda.is_available() and not no_cuda else "cpu" 112 | ) 113 | 114 | 115 | def extract_weight_method_parameters_from_args(args): 116 | weight_methods_parameters = defaultdict(dict) 117 | weight_methods_parameters.update( 118 | dict( 119 | nashmtl=dict( 120 | update_weights_every=args.update_weights_every, 121 | optim_niter=args.nashmtl_optim_niter, 122 | ), 123 | stl=dict(main_task=args.main_task), 124 | cagrad=dict(c=args.c), 125 | dwa=dict(temp=args.dwa_temp), 126 | ) 127 | ) 128 | return weight_methods_parameters 129 | -------------------------------------------------------------------------------- /experiments/toy/trainer.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import logging 3 | from argparse import ArgumentParser 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from experiments.toy.problem import Toy 12 | from experiments.toy.utils import plot_2d_pareto 13 | from experiments.utils import ( 14 | common_parser, 15 | extract_weight_method_parameters_from_args, 16 | set_logger, 17 | ) 18 | from methods.weight_methods import WeightMethods 19 | 20 | set_logger() 21 | 22 | 23 | def main(method_type, device, n_iter, scale): 24 | weight_methods_parameters = extract_weight_method_parameters_from_args(args) 25 | n_tasks = 2 26 | 27 | F = Toy(scale=scale) 28 | 29 | all_traj = dict() 30 | 31 | # the initial positions 32 | inits = [ 33 | torch.Tensor([-8.5, 7.5]), 34 | torch.Tensor([0.0, 0.0]), 35 | torch.Tensor([9.0, 9.0]), 36 | torch.Tensor([-7.5, -0.5]), 37 | torch.Tensor([9, -1.0]), 38 | ] 39 | 40 | for i, init in enumerate(inits): 41 | traj = [] 42 | x = init.clone() 43 | x.requires_grad = True 44 | x = x.to(device) 45 | 46 | method = WeightMethods( 47 | method=method_type, 48 | device=device, 49 | n_tasks=n_tasks, 50 | **weight_methods_parameters[method_type], 51 | ) 52 | 53 | optimizer = torch.optim.Adam( 54 | [ 55 | dict(params=[x], lr=1e-3), 56 | dict(params=method.parameters(), lr=args.method_params_lr), 57 | ], 58 | ) 59 | 60 | for _ in tqdm(range(n_iter)): 61 | traj.append(x.cpu().detach().numpy().copy()) 62 | 63 | optimizer.zero_grad() 64 | f = F(x, False) 65 | _ = method.backward( 66 | losses=f, 67 | shared_parameters=(x,), 68 | task_specific_parameters=None, 69 | last_shared_parameters=None, 70 | representation=None, 71 | ) 72 | optimizer.step() 73 | 74 | all_traj[i] = dict(init=init.cpu().detach().numpy().copy(), traj=np.array(traj)) 75 | 76 | return all_traj 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = ArgumentParser( 81 | "Toy example (modification of the one in CAGrad)", parents=[common_parser] 82 | ) 83 | parser.set_defaults(n_epochs=35000, method="nashmtl", data_path=None) 84 | parser.add_argument( 85 | "--scale", default=1e-1, type=float, help="scale for first loss" 86 | ) 87 | parser.add_argument("--out-path", default="outputs", type=Path, help="output path") 88 | parser.add_argument("--wandb_project", type=str, default=None, help="Name of Weights & Biases Project.") 89 | parser.add_argument("--wandb_entity", type=str, default=None, help="Name of Weights & Biases Entity.") 90 | args = parser.parse_args() 91 | 92 | if args.wandb_project is not None: 93 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args) 94 | 95 | out_path = args.out_path 96 | out_path.mkdir(parents=True, exist_ok=True) 97 | logging.info(f"Logs and plots are saved in: {out_path.as_posix()}") 98 | 99 | device = torch.device("cpu") 100 | all_traj = main( 101 | method_type=args.method, device=device, n_iter=args.n_epochs, scale=args.scale 102 | ) 103 | 104 | # plot 105 | ax, fig, legend = plot_2d_pareto(trajectories=all_traj, scale=args.scale) 106 | 107 | title_map = { 108 | "nashmtl": "Nash-MTL", 109 | "cagrad": "CAGrad", 110 | "mgda": "MGDA", 111 | "pcgrad": "PCGrad", 112 | "ls": "LS", 113 | } 114 | ax.set_title(title_map[args.method], fontsize=25) 115 | plt.savefig( 116 | out_path / f"{args.method}.png", 117 | bbox_extra_artists=(legend,), 118 | bbox_inches="tight", 119 | facecolor="white", 120 | ) 121 | plt.close() 122 | 123 | if wandb.run is not None: 124 | wandb.log({"Pareto Front": wandb.Image((out_path / f"{args.method}.png").as_posix())}) 125 | 126 | wandb.finish() 127 | -------------------------------------------------------------------------------- /experiments/nyuv2/data.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data.dataset import Dataset 9 | 10 | """Source: https://github.com/Cranial-XIX/CAGrad/blob/main/nyuv2/create_dataset.py 11 | 12 | """ 13 | 14 | 15 | class RandomScaleCrop(object): 16 | """ 17 | Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34. 18 | """ 19 | 20 | def __init__(self, scale=[1.0, 1.2, 1.5]): 21 | self.scale = scale 22 | 23 | def __call__(self, img, label, depth, normal): 24 | height, width = img.shape[-2:] 25 | sc = self.scale[random.randint(0, len(self.scale) - 1)] 26 | h, w = int(height / sc), int(width / sc) 27 | i = random.randint(0, height - h) 28 | j = random.randint(0, width - w) 29 | img_ = F.interpolate( 30 | img[None, :, i : i + h, j : j + w], 31 | size=(height, width), 32 | mode="bilinear", 33 | align_corners=True, 34 | ).squeeze(0) 35 | label_ = ( 36 | F.interpolate( 37 | label[None, None, i : i + h, j : j + w], 38 | size=(height, width), 39 | mode="nearest", 40 | ) 41 | .squeeze(0) 42 | .squeeze(0) 43 | ) 44 | depth_ = F.interpolate( 45 | depth[None, :, i : i + h, j : j + w], size=(height, width), mode="nearest" 46 | ).squeeze(0) 47 | normal_ = F.interpolate( 48 | normal[None, :, i : i + h, j : j + w], 49 | size=(height, width), 50 | mode="bilinear", 51 | align_corners=True, 52 | ).squeeze(0) 53 | return img_, label_, depth_ / sc, normal_ 54 | 55 | 56 | class NYUv2(Dataset): 57 | """ 58 | We could further improve the performance with the data augmentation of NYUv2 defined in: 59 | [1] PAD-Net: Multi-Tasks Guided Prediction-and-Distillation Network for Simultaneous Depth Estimation and Scene Parsing 60 | [2] Pattern affinitive propagation across depth, surface normal and semantic segmentation 61 | [3] Mti-net: Multiscale task interaction networks for multi-task learning 62 | 1. Random scale in a selected raio 1.0, 1.2, and 1.5. 63 | 2. Random horizontal flip. 64 | Please note that: all baselines and MTAN did NOT apply data augmentation in the original paper. 65 | """ 66 | 67 | def __init__(self, root, train=True, augmentation=False): 68 | self.train = train 69 | self.root = os.path.expanduser(root) 70 | self.augmentation = augmentation 71 | 72 | # read the data file 73 | if train: 74 | self.data_path = root + "/train" 75 | else: 76 | self.data_path = root + "/val" 77 | 78 | # calculate data length 79 | self.data_len = len( 80 | fnmatch.filter(os.listdir(self.data_path + "/image"), "*.npy") 81 | ) 82 | 83 | def __getitem__(self, index): 84 | # load data from the pre-processed npy files 85 | image = torch.from_numpy( 86 | np.moveaxis( 87 | np.load(self.data_path + "/image/{:d}.npy".format(index)), -1, 0 88 | ) 89 | ) 90 | semantic = torch.from_numpy( 91 | np.load(self.data_path + "/label/{:d}.npy".format(index)) 92 | ) 93 | depth = torch.from_numpy( 94 | np.moveaxis( 95 | np.load(self.data_path + "/depth/{:d}.npy".format(index)), -1, 0 96 | ) 97 | ) 98 | normal = torch.from_numpy( 99 | np.moveaxis( 100 | np.load(self.data_path + "/normal/{:d}.npy".format(index)), -1, 0 101 | ) 102 | ) 103 | 104 | # apply data augmentation if required 105 | if self.augmentation: 106 | image, semantic, depth, normal = RandomScaleCrop()( 107 | image, semantic, depth, normal 108 | ) 109 | if torch.rand(1) < 0.5: 110 | image = torch.flip(image, dims=[2]) 111 | semantic = torch.flip(semantic, dims=[1]) 112 | depth = torch.flip(depth, dims=[2]) 113 | normal = torch.flip(normal, dims=[2]) 114 | normal[0, :, :] = -normal[0, :, :] 115 | 116 | return image.float(), semantic.float(), depth.float(), normal.float() 117 | 118 | def __len__(self): 119 | return self.data_len 120 | -------------------------------------------------------------------------------- /experiments/quantum_chemistry/trainer.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | import torch_geometric.transforms as T 7 | from torch_geometric.datasets import QM9 8 | from torch_geometric.loader import DataLoader 9 | from tqdm import trange 10 | import wandb 11 | 12 | from experiments.quantum_chemistry.models import Net 13 | from experiments.quantum_chemistry.utils import ( 14 | Complete, 15 | MyTransform, 16 | delta_fn, 17 | multiply_indx, 18 | ) 19 | from experiments.quantum_chemistry.utils import target_idx as targets 20 | from experiments.utils import ( 21 | common_parser, 22 | extract_weight_method_parameters_from_args, 23 | get_device, 24 | set_logger, 25 | set_seed, 26 | str2bool, 27 | ) 28 | from methods.weight_methods import WeightMethods 29 | 30 | set_logger() 31 | 32 | 33 | @torch.no_grad() 34 | def evaluate(model, loader, std, scale_target): 35 | model.eval() 36 | data_size = 0.0 37 | task_losses = 0.0 38 | for i, data in enumerate(loader): 39 | data = data.to(device) 40 | out = model(data) 41 | if scale_target: 42 | task_losses += F.l1_loss( 43 | out * std.to(device), data.y * std.to(device), reduction="none" 44 | ).sum( 45 | 0 46 | ) # MAE 47 | else: 48 | task_losses += F.l1_loss(out, data.y, reduction="none").sum(0) # MAE 49 | data_size += len(data.y) 50 | 51 | model.train() 52 | 53 | avg_task_losses = task_losses / data_size 54 | 55 | # Report meV instead of eV. 56 | avg_task_losses = avg_task_losses.detach().cpu().numpy() 57 | avg_task_losses[multiply_indx] *= 1000 58 | 59 | delta_m = delta_fn(avg_task_losses) 60 | return dict( 61 | avg_loss=avg_task_losses.mean(), 62 | avg_task_losses=avg_task_losses, 63 | delta_m=delta_m, 64 | ) 65 | 66 | 67 | def main( 68 | data_path: str, 69 | batch_size: int, 70 | device: torch.device, 71 | method: str, 72 | weight_method_params: dict, 73 | lr: float, 74 | method_params_lr: float, 75 | n_epochs: int, 76 | targets: list = None, 77 | scale_target: bool = True, 78 | main_task: int = None, 79 | ): 80 | dim = 64 81 | model = Net(n_tasks=len(targets), num_features=11, dim=dim).to(device) 82 | 83 | transform = T.Compose([MyTransform(targets), Complete(), T.Distance(norm=False)]) 84 | dataset = QM9(data_path, transform=transform).shuffle() 85 | 86 | # Split datasets. 87 | test_dataset = dataset[:10000] 88 | val_dataset = dataset[10000:20000] 89 | train_dataset = dataset[20000:] 90 | 91 | std = None 92 | if scale_target: 93 | mean = train_dataset.data.y[:, targets].mean(dim=0, keepdim=True) 94 | std = train_dataset.data.y[:, targets].std(dim=0, keepdim=True) 95 | 96 | dataset.data.y[:, targets] = (dataset.data.y[:, targets] - mean) / std 97 | 98 | test_loader = DataLoader( 99 | test_dataset, batch_size=batch_size, shuffle=False, num_workers=0 100 | ) 101 | val_loader = DataLoader( 102 | val_dataset, batch_size=batch_size, shuffle=False, num_workers=0 103 | ) 104 | train_loader = DataLoader( 105 | train_dataset, batch_size=batch_size, shuffle=True, num_workers=0 106 | ) 107 | 108 | weight_method = WeightMethods( 109 | method, 110 | n_tasks=len(targets), 111 | device=device, 112 | **weight_method_params[method], 113 | ) 114 | 115 | optimizer = torch.optim.Adam( 116 | [ 117 | dict(params=model.parameters(), lr=lr), 118 | dict(params=weight_method.parameters(), lr=method_params_lr), 119 | ], 120 | ) 121 | 122 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 123 | optimizer, mode="min", factor=0.7, patience=5, min_lr=0.00001 124 | ) 125 | 126 | epoch_iterator = trange(n_epochs) 127 | 128 | best_val = np.inf 129 | best_test = np.inf 130 | best_test_delta = np.inf 131 | best_val_delta = np.inf 132 | best_test_results = None 133 | 134 | for epoch in epoch_iterator: 135 | lr = scheduler.optimizer.param_groups[0]["lr"] 136 | for j, data in enumerate(train_loader): 137 | model.train() 138 | 139 | data = data.to(device) 140 | optimizer.zero_grad() 141 | 142 | out, features = model(data, return_representation=True) 143 | 144 | losses = F.mse_loss(out, data.y, reduction="none").mean(0) 145 | 146 | loss, extra_outputs = weight_method.backward( 147 | losses=losses, 148 | shared_parameters=list(model.shared_parameters()), 149 | task_specific_parameters=list(model.task_specific_parameters()), 150 | last_shared_parameters=list(model.last_shared_parameters()), 151 | representation=features, 152 | ) 153 | 154 | optimizer.step() 155 | 156 | 157 | 158 | val_loss_dict = evaluate(model, val_loader, std=std, scale_target=scale_target) 159 | test_loss_dict = evaluate( 160 | model, test_loader, std=std, scale_target=scale_target 161 | ) 162 | val_loss = val_loss_dict["avg_loss"] 163 | val_delta = val_loss_dict["delta_m"] 164 | test_loss = test_loss_dict["avg_loss"] 165 | test_delta = test_loss_dict["delta_m"] 166 | 167 | if method == "stl": 168 | best_val_criteria = val_loss_dict["avg_task_losses"][main_task] <= best_val 169 | else: 170 | best_val_criteria = val_delta <= best_val_delta 171 | 172 | if best_val_criteria: 173 | best_val = val_loss 174 | best_test = test_loss 175 | best_test_results = test_loss_dict 176 | best_val_delta = val_delta 177 | best_test_delta = test_delta 178 | 179 | # for logger 180 | epoch_iterator.set_description( 181 | f"epoch {epoch} | lr={lr} | train loss {losses.mean().item():.3f} | val loss: {val_loss:.3f} | " 182 | f"test loss: {test_loss:.3f} | best test loss {best_test:.3f} | best_test_delta {best_test_delta:.3f}" 183 | ) 184 | 185 | if wandb.run is not None: 186 | wandb.log({"Learning Rate": lr}, step=epoch) 187 | wandb.log({"Train Loss": losses.mean().item()}, step=epoch) 188 | wandb.log({"Val Loss": val_loss}, step=epoch) 189 | wandb.log({"Val Delta": val_delta}, step=epoch) 190 | wandb.log({"Test Loss": test_loss}, step=epoch) 191 | wandb.log({"Test Delta": test_delta}, step=epoch) 192 | wandb.log({"Best Test Loss": best_test}, step=epoch) 193 | wandb.log({"Best Test Delta": best_test_delta}, step=epoch) 194 | 195 | scheduler.step( 196 | val_loss_dict["avg_task_losses"][main_task] 197 | if method == "stl" 198 | else val_delta 199 | ) 200 | 201 | 202 | if __name__ == "__main__": 203 | parser = ArgumentParser("QM9", parents=[common_parser]) 204 | parser.set_defaults( 205 | data_path="dataset", 206 | lr=1e-3, 207 | n_epochs=300, 208 | batch_size=120, 209 | method="nashmtl", 210 | ) 211 | parser.add_argument("--scale-y", default=True, type=str2bool) 212 | parser.add_argument("--wandb_project", type=str, default=None, help="Name of Weights & Biases Project.") 213 | parser.add_argument("--wandb_entity", type=str, default=None, help="Name of Weights & Biases Entity.") 214 | args = parser.parse_args() 215 | 216 | # set seed 217 | set_seed(args.seed) 218 | 219 | if args.wandb_project is not None: 220 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args) 221 | 222 | weight_method_params = extract_weight_method_parameters_from_args(args) 223 | 224 | device = get_device(gpus=args.gpu) 225 | main( 226 | data_path=args.data_path, 227 | batch_size=args.batch_size, 228 | device=device, 229 | method=args.method, 230 | weight_method_params=weight_method_params, 231 | lr=args.lr, 232 | method_params_lr=args.method_params_lr, 233 | n_epochs=args.n_epochs, 234 | targets=targets, 235 | scale_target=args.scale_y, 236 | main_task=args.main_task, 237 | ) 238 | 239 | if wandb.run is not None: 240 | wandb.finish() 241 | -------------------------------------------------------------------------------- /methods/min_norm_solvers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | # This code is from 6 | # Multi-Task Learning as Multi-Objective Optimization 7 | # Ozan Sener, Vladlen Koltun 8 | # Neural Information Processing Systems (NeurIPS) 2018 9 | # https://github.com/intel-isl/MultiObjectiveOptimization 10 | class MinNormSolver: 11 | MAX_ITER = 250 12 | STOP_CRIT = 1e-5 13 | 14 | @staticmethod 15 | def _min_norm_element_from2(v1v1, v1v2, v2v2): 16 | """ 17 | Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2 18 | d is the distance (objective) optimzed 19 | v1v1 = 20 | v1v2 = 21 | v2v2 = 22 | """ 23 | if v1v2 >= v1v1: 24 | # Case: Fig 1, third column 25 | gamma = 0.999 26 | cost = v1v1 27 | return gamma, cost 28 | if v1v2 >= v2v2: 29 | # Case: Fig 1, first column 30 | gamma = 0.001 31 | cost = v2v2 32 | return gamma, cost 33 | # Case: Fig 1, second column 34 | gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2)) 35 | cost = v2v2 + gamma * (v1v2 - v2v2) 36 | return gamma, cost 37 | 38 | @staticmethod 39 | def _min_norm_2d(vecs, dps): 40 | """ 41 | Find the minimum norm solution as combination of two points 42 | This is correct only in 2D 43 | ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j 44 | """ 45 | dmin = 1e8 46 | for i in range(len(vecs)): 47 | for j in range(i + 1, len(vecs)): 48 | if (i, j) not in dps: 49 | dps[(i, j)] = 0.0 50 | for k in range(len(vecs[i])): 51 | dps[(i, j)] += torch.dot( 52 | vecs[i][k], vecs[j][k] 53 | ).item() # torch.dot(vecs[i][k], vecs[j][k]).dataset[0] 54 | dps[(j, i)] = dps[(i, j)] 55 | if (i, i) not in dps: 56 | dps[(i, i)] = 0.0 57 | for k in range(len(vecs[i])): 58 | dps[(i, i)] += torch.dot( 59 | vecs[i][k], vecs[i][k] 60 | ).item() # torch.dot(vecs[i][k], vecs[i][k]).dataset[0] 61 | if (j, j) not in dps: 62 | dps[(j, j)] = 0.0 63 | for k in range(len(vecs[i])): 64 | dps[(j, j)] += torch.dot( 65 | vecs[j][k], vecs[j][k] 66 | ).item() # torch.dot(vecs[j][k], vecs[j][k]).dataset[0] 67 | c, d = MinNormSolver._min_norm_element_from2( 68 | dps[(i, i)], dps[(i, j)], dps[(j, j)] 69 | ) 70 | if d < dmin: 71 | dmin = d 72 | sol = [(i, j), c, d] 73 | return sol, dps 74 | 75 | @staticmethod 76 | def _projection2simplex(y): 77 | """ 78 | Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i 79 | """ 80 | m = len(y) 81 | sorted_y = np.flip(np.sort(y), axis=0) 82 | tmpsum = 0.0 83 | tmax_f = (np.sum(y) - 1.0) / m 84 | for i in range(m - 1): 85 | tmpsum += sorted_y[i] 86 | tmax = (tmpsum - 1) / (i + 1.0) 87 | if tmax > sorted_y[i + 1]: 88 | tmax_f = tmax 89 | break 90 | return np.maximum(y - tmax_f, np.zeros(y.shape)) 91 | 92 | @staticmethod 93 | def _next_point(cur_val, grad, n): 94 | proj_grad = grad - (np.sum(grad) / n) 95 | tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0] 96 | tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0]) 97 | 98 | skippers = np.sum(tm1 < 1e-7) + np.sum(tm2 < 1e-7) 99 | t = 1 100 | if len(tm1[tm1 > 1e-7]) > 0: 101 | t = np.min(tm1[tm1 > 1e-7]) 102 | if len(tm2[tm2 > 1e-7]) > 0: 103 | t = min(t, np.min(tm2[tm2 > 1e-7])) 104 | 105 | next_point = proj_grad * t + cur_val 106 | next_point = MinNormSolver._projection2simplex(next_point) 107 | return next_point 108 | 109 | @staticmethod 110 | def find_min_norm_element(vecs): 111 | """ 112 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull 113 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1. 114 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j}) 115 | Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence 116 | """ 117 | # Solution lying at the combination of two points 118 | dps = {} 119 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps) 120 | 121 | n = len(vecs) 122 | sol_vec = np.zeros(n) 123 | sol_vec[init_sol[0][0]] = init_sol[1] 124 | sol_vec[init_sol[0][1]] = 1 - init_sol[1] 125 | 126 | if n < 3: 127 | # This is optimal for n=2, so return the solution 128 | return sol_vec, init_sol[2] 129 | 130 | iter_count = 0 131 | 132 | grad_mat = np.zeros((n, n)) 133 | for i in range(n): 134 | for j in range(n): 135 | grad_mat[i, j] = dps[(i, j)] 136 | 137 | while iter_count < MinNormSolver.MAX_ITER: 138 | grad_dir = -1.0 * np.dot(grad_mat, sol_vec) 139 | new_point = MinNormSolver._next_point(sol_vec, grad_dir, n) 140 | # Re-compute the inner products for line search 141 | v1v1 = 0.0 142 | v1v2 = 0.0 143 | v2v2 = 0.0 144 | for i in range(n): 145 | for j in range(n): 146 | v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)] 147 | v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)] 148 | v2v2 += new_point[i] * new_point[j] * dps[(i, j)] 149 | nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2) 150 | new_sol_vec = nc * sol_vec + (1 - nc) * new_point 151 | change = new_sol_vec - sol_vec 152 | if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT: 153 | return sol_vec, nd 154 | sol_vec = new_sol_vec 155 | 156 | @staticmethod 157 | def find_min_norm_element_FW(vecs): 158 | """ 159 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull 160 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1. 161 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j}) 162 | Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence 163 | """ 164 | # Solution lying at the combination of two points 165 | dps = {} 166 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps) 167 | 168 | n = len(vecs) 169 | sol_vec = np.zeros(n) 170 | sol_vec[init_sol[0][0]] = init_sol[1] 171 | sol_vec[init_sol[0][1]] = 1 - init_sol[1] 172 | 173 | if n < 3: 174 | # This is optimal for n=2, so return the solution 175 | return sol_vec, init_sol[2] 176 | 177 | iter_count = 0 178 | 179 | grad_mat = np.zeros((n, n)) 180 | for i in range(n): 181 | for j in range(n): 182 | grad_mat[i, j] = dps[(i, j)] 183 | 184 | while iter_count < MinNormSolver.MAX_ITER: 185 | t_iter = np.argmin(np.dot(grad_mat, sol_vec)) 186 | 187 | v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec)) 188 | v1v2 = np.dot(sol_vec, grad_mat[:, t_iter]) 189 | v2v2 = grad_mat[t_iter, t_iter] 190 | 191 | nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2) 192 | new_sol_vec = nc * sol_vec 193 | new_sol_vec[t_iter] += 1 - nc 194 | 195 | change = new_sol_vec - sol_vec 196 | if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT: 197 | return sol_vec, nd 198 | sol_vec = new_sol_vec 199 | 200 | 201 | def gradient_normalizers(grads, losses, normalization_type): 202 | gn = {} 203 | if normalization_type == "norm": 204 | for t in grads: 205 | gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data[0] for gr in grads[t]])) 206 | elif normalization_type == "loss": 207 | for t in grads: 208 | gn[t] = losses[t] 209 | elif normalization_type == "loss+": 210 | for t in grads: 211 | gn[t] = losses[t] * np.sqrt( 212 | np.sum([gr.pow(2).sum().data[0] for gr in grads[t]]) 213 | ) 214 | elif normalization_type == "none": 215 | for t in grads: 216 | gn[t] = 1.0 217 | else: 218 | print("ERROR: Invalid Normalization Type") 219 | return gn 220 | -------------------------------------------------------------------------------- /experiments/nyuv2/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import wandb 3 | from argparse import ArgumentParser 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | from tqdm import trange 10 | 11 | from experiments.nyuv2.data import NYUv2 12 | from experiments.nyuv2.models import SegNet, SegNetMtan 13 | from experiments.nyuv2.utils import ConfMatrix, delta_fn, depth_error, normal_error 14 | from experiments.utils import ( 15 | common_parser, 16 | extract_weight_method_parameters_from_args, 17 | get_device, 18 | set_logger, 19 | set_seed, 20 | str2bool, 21 | ) 22 | from methods.weight_methods import WeightMethods 23 | 24 | set_logger() 25 | 26 | 27 | def calc_loss(x_pred, x_output, task_type): 28 | device = x_pred.device 29 | 30 | # binary mark to mask out undefined pixel space 31 | binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(device) 32 | 33 | if task_type == "semantic": 34 | # semantic loss: depth-wise cross entropy 35 | loss = F.nll_loss(x_pred, x_output, ignore_index=-1) 36 | 37 | if task_type == "depth": 38 | # depth loss: l1 norm 39 | loss = torch.sum(torch.abs(x_pred - x_output) * binary_mask) / torch.nonzero( 40 | binary_mask, as_tuple=False 41 | ).size(0) 42 | 43 | if task_type == "normal": 44 | # normal loss: dot product 45 | loss = 1 - torch.sum((x_pred * x_output) * binary_mask) / torch.nonzero( 46 | binary_mask, as_tuple=False 47 | ).size(0) 48 | 49 | return loss 50 | 51 | 52 | def main(path, lr, bs, device): 53 | # ---- 54 | # Nets 55 | # --- 56 | model = dict(segnet=SegNet(), mtan=SegNetMtan())[args.model] 57 | model = model.to(device) 58 | 59 | # dataset and dataloaders 60 | log_str = ( 61 | "Applying data augmentation on NYUv2." 62 | if args.apply_augmentation 63 | else "Standard training strategy without data augmentation." 64 | ) 65 | logging.info(log_str) 66 | 67 | nyuv2_train_set = NYUv2( 68 | root=path.as_posix(), train=True, augmentation=args.apply_augmentation 69 | ) 70 | nyuv2_test_set = NYUv2(root=path.as_posix(), train=False) 71 | 72 | train_loader = torch.utils.data.DataLoader( 73 | dataset=nyuv2_train_set, batch_size=bs, shuffle=True 74 | ) 75 | 76 | test_loader = torch.utils.data.DataLoader( 77 | dataset=nyuv2_test_set, batch_size=bs, shuffle=False 78 | ) 79 | 80 | # weight method 81 | weight_methods_parameters = extract_weight_method_parameters_from_args(args) 82 | weight_method = WeightMethods( 83 | args.method, n_tasks=3, device=device, **weight_methods_parameters[args.method] 84 | ) 85 | 86 | # optimizer 87 | optimizer = torch.optim.Adam( 88 | [ 89 | dict(params=model.parameters(), lr=lr), 90 | dict(params=weight_method.parameters(), lr=args.method_params_lr), 91 | ], 92 | ) 93 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) 94 | 95 | epochs = args.n_epochs 96 | epoch_iter = trange(epochs) 97 | train_batch = len(train_loader) 98 | test_batch = len(test_loader) 99 | avg_cost = np.zeros([epochs, 24], dtype=np.float32) 100 | custom_step = -1 101 | conf_mat = ConfMatrix(model.segnet.class_nb) 102 | for epoch in epoch_iter: 103 | cost = np.zeros(24, dtype=np.float32) 104 | 105 | for j, batch in enumerate(train_loader): 106 | custom_step += 1 107 | 108 | model.train() 109 | optimizer.zero_grad() 110 | 111 | train_data, train_label, train_depth, train_normal = batch 112 | train_data, train_label = train_data.to(device), train_label.long().to( 113 | device 114 | ) 115 | train_depth, train_normal = train_depth.to(device), train_normal.to(device) 116 | 117 | train_pred, features = model(train_data, return_representation=True) 118 | 119 | losses = torch.stack( 120 | ( 121 | calc_loss(train_pred[0], train_label, "semantic"), 122 | calc_loss(train_pred[1], train_depth, "depth"), 123 | calc_loss(train_pred[2], train_normal, "normal"), 124 | ) 125 | ) 126 | 127 | loss, extra_outputs = weight_method.backward( 128 | losses=losses, 129 | shared_parameters=list(model.shared_parameters()), 130 | task_specific_parameters=list(model.task_specific_parameters()), 131 | last_shared_parameters=list(model.last_shared_parameters()), 132 | representation=features, 133 | ) 134 | 135 | optimizer.step() 136 | 137 | # accumulate label prediction for every pixel in training images 138 | conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten()) 139 | 140 | cost[0] = losses[0].item() 141 | cost[3] = losses[1].item() 142 | cost[4], cost[5] = depth_error(train_pred[1], train_depth) 143 | cost[6] = losses[2].item() 144 | cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error( 145 | train_pred[2], train_normal 146 | ) 147 | avg_cost[epoch, :12] += cost[:12] / train_batch 148 | 149 | epoch_iter.set_description( 150 | f"[{epoch+1} {j+1}/{train_batch}] semantic loss: {losses[0].item():.3f}, " 151 | f"depth loss: {losses[1].item():.3f}, " 152 | f"normal loss: {losses[2].item():.3f}" 153 | ) 154 | 155 | # scheduler 156 | scheduler.step() 157 | # compute mIoU and acc 158 | avg_cost[epoch, 1:3] = conf_mat.get_metrics() 159 | 160 | # todo: move evaluate to function? 161 | # evaluating test data 162 | model.eval() 163 | conf_mat = ConfMatrix(model.segnet.class_nb) 164 | with torch.no_grad(): # operations inside don't track history 165 | test_dataset = iter(test_loader) 166 | for k in range(test_batch): 167 | test_data, test_label, test_depth, test_normal = next(test_dataset) #test_dataset.next()#.next is deprecated 168 | test_data, test_label = test_data.to(device), test_label.long().to( 169 | device 170 | ) 171 | test_depth, test_normal = test_depth.to(device), test_normal.to(device) 172 | 173 | test_pred = model(test_data) 174 | test_loss = torch.stack( 175 | ( 176 | calc_loss(test_pred[0], test_label, "semantic"), 177 | calc_loss(test_pred[1], test_depth, "depth"), 178 | calc_loss(test_pred[2], test_normal, "normal"), 179 | ) 180 | ) 181 | 182 | conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten()) 183 | 184 | cost[12] = test_loss[0].item() 185 | cost[15] = test_loss[1].item() 186 | cost[16], cost[17] = depth_error(test_pred[1], test_depth) 187 | cost[18] = test_loss[2].item() 188 | cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error( 189 | test_pred[2], test_normal 190 | ) 191 | avg_cost[epoch, 12:] += cost[12:] / test_batch 192 | 193 | # compute mIoU and acc 194 | avg_cost[epoch, 13:15] = conf_mat.get_metrics() 195 | 196 | # Test Delta_m 197 | test_delta_m = delta_fn( 198 | avg_cost[epoch, [13, 14, 16, 17, 19, 20, 21, 22, 23]] 199 | ) 200 | 201 | # print results 202 | print( 203 | f"LOSS FORMAT: SEMANTIC_LOSS MEAN_IOU PIX_ACC | DEPTH_LOSS ABS_ERR REL_ERR " 204 | f"| NORMAL_LOSS MEAN MED <11.25 <22.5 <30 | ∆m (test)" 205 | ) 206 | print( 207 | f"Epoch: {epoch:04d} | TRAIN: {avg_cost[epoch, 0]:.4f} {avg_cost[epoch, 1]:.4f} {avg_cost[epoch, 2]:.4f} " 208 | f"| {avg_cost[epoch, 3]:.4f} {avg_cost[epoch, 4]:.4f} {avg_cost[epoch, 5]:.4f} | {avg_cost[epoch, 6]:.4f} " 209 | f"{avg_cost[epoch, 7]:.4f} {avg_cost[epoch, 8]:.4f} {avg_cost[epoch, 9]:.4f} {avg_cost[epoch, 10]:.4f} {avg_cost[epoch, 11]:.4f} || " 210 | f"TEST: {avg_cost[epoch, 12]:.4f} {avg_cost[epoch, 13]:.4f} {avg_cost[epoch, 14]:.4f} | " 211 | f"{avg_cost[epoch, 15]:.4f} {avg_cost[epoch, 16]:.4f} {avg_cost[epoch, 17]:.4f} | {avg_cost[epoch, 18]:.4f} " 212 | f"{avg_cost[epoch, 19]:.4f} {avg_cost[epoch, 20]:.4f} {avg_cost[epoch, 21]:.4f} {avg_cost[epoch, 22]:.4f} {avg_cost[epoch, 23]:.4f} " 213 | f"| {test_delta_m:.3f}" 214 | ) 215 | 216 | if wandb.run is not None: 217 | wandb.log({"Train Semantic Loss": avg_cost[epoch, 0]}, step=epoch) 218 | wandb.log({"Train Mean IoU": avg_cost[epoch, 1]}, step=epoch) 219 | wandb.log({"Train Pixel Accuracy": avg_cost[epoch, 2]}, step=epoch) 220 | wandb.log({"Train Depth Loss": avg_cost[epoch, 3]}, step=epoch) 221 | wandb.log({"Train Absolute Error": avg_cost[epoch, 4]}, step=epoch) 222 | wandb.log({"Train Relative Error": avg_cost[epoch, 5]}, step=epoch) 223 | wandb.log({"Train Normal Loss": avg_cost[epoch, 6]}, step=epoch) 224 | wandb.log({"Train Loss Mean": avg_cost[epoch, 7]}, step=epoch) 225 | wandb.log({"Train Loss Med": avg_cost[epoch, 8]}, step=epoch) 226 | wandb.log({"Train Loss <11.25": avg_cost[epoch, 9]}, step=epoch) 227 | wandb.log({"Train Loss <22.5": avg_cost[epoch, 10]}, step=epoch) 228 | wandb.log({"Train Loss <30": avg_cost[epoch, 11]}, step=epoch) 229 | 230 | wandb.log({"Test Semantic Loss": avg_cost[epoch, 12]}, step=epoch) 231 | wandb.log({"Test Mean IoU": avg_cost[epoch, 13]}, step=epoch) 232 | wandb.log({"Test Pixel Accuracy": avg_cost[epoch, 14]}, step=epoch) 233 | wandb.log({"Test Depth Loss": avg_cost[epoch, 15]}, step=epoch) 234 | wandb.log({"Test Absolute Error": avg_cost[epoch, 16]}, step=epoch) 235 | wandb.log({"Test Relative Error": avg_cost[epoch, 17]}, step=epoch) 236 | wandb.log({"Test Normal Loss": avg_cost[epoch, 18]}, step=epoch) 237 | wandb.log({"Test Loss Mean": avg_cost[epoch, 19]}, step=epoch) 238 | wandb.log({"Test Loss Med": avg_cost[epoch, 20]}, step=epoch) 239 | wandb.log({"Test Loss <11.25": avg_cost[epoch, 21]}, step=epoch) 240 | wandb.log({"Test Loss <22.5": avg_cost[epoch, 22]}, step=epoch) 241 | wandb.log({"Test Loss <30": avg_cost[epoch, 23]}, step=epoch) 242 | wandb.log({"Test ∆m": test_delta_m}, step=epoch) 243 | 244 | 245 | if __name__ == "__main__": 246 | parser = ArgumentParser("NYUv2", parents=[common_parser]) 247 | parser.set_defaults( 248 | data_path="dataset", 249 | lr=1e-4, 250 | n_epochs=200, 251 | batch_size=2, 252 | ) 253 | parser.add_argument( 254 | "--model", 255 | type=str, 256 | default="mtan", 257 | choices=["segnet", "mtan"], 258 | help="model type", 259 | ) 260 | parser.add_argument( 261 | "--apply-augmentation", type=str2bool, default=True, help="data augmentations" 262 | ) 263 | parser.add_argument("--wandb_project", type=str, default=None, help="Name of Weights & Biases Project.") 264 | parser.add_argument("--wandb_entity", type=str, default=None, help="Name of Weights & Biases Entity.") 265 | args = parser.parse_args() 266 | 267 | # set seed 268 | set_seed(args.seed) 269 | 270 | if args.wandb_project is not None: 271 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args) 272 | 273 | device = get_device(gpus=args.gpu) 274 | main(path=args.data_path, lr=args.lr, bs=args.batch_size, device=device) 275 | 276 | if wandb.run is not None: 277 | wandb.finish() 278 | -------------------------------------------------------------------------------- /experiments/nyuv2/models.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class _SegNet(nn.Module): 9 | """SegNet MTAN""" 10 | 11 | def __init__(self): 12 | super(_SegNet, self).__init__() 13 | # initialise network parameters 14 | filter = [64, 128, 256, 512, 512] 15 | self.class_nb = 13 16 | 17 | # define encoder decoder layers 18 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])]) 19 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 20 | for i in range(4): 21 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]])) 22 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]])) 23 | 24 | # define convolution layer 25 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 26 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 27 | for i in range(4): 28 | if i == 0: 29 | self.conv_block_enc.append( 30 | self.conv_layer([filter[i + 1], filter[i + 1]]) 31 | ) 32 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]])) 33 | else: 34 | self.conv_block_enc.append( 35 | nn.Sequential( 36 | self.conv_layer([filter[i + 1], filter[i + 1]]), 37 | self.conv_layer([filter[i + 1], filter[i + 1]]), 38 | ) 39 | ) 40 | self.conv_block_dec.append( 41 | nn.Sequential( 42 | self.conv_layer([filter[i], filter[i]]), 43 | self.conv_layer([filter[i], filter[i]]), 44 | ) 45 | ) 46 | 47 | # define task attention layers 48 | self.encoder_att = nn.ModuleList( 49 | [nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])] 50 | ) 51 | self.decoder_att = nn.ModuleList( 52 | [nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])])] 53 | ) 54 | self.encoder_block_att = nn.ModuleList( 55 | [self.conv_layer([filter[0], filter[1]])] 56 | ) 57 | self.decoder_block_att = nn.ModuleList( 58 | [self.conv_layer([filter[0], filter[0]])] 59 | ) 60 | 61 | for j in range(3): 62 | if j < 2: 63 | self.encoder_att.append( 64 | nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])]) 65 | ) 66 | self.decoder_att.append( 67 | nn.ModuleList( 68 | [self.att_layer([2 * filter[0], filter[0], filter[0]])] 69 | ) 70 | ) 71 | for i in range(4): 72 | self.encoder_att[j].append( 73 | self.att_layer([2 * filter[i + 1], filter[i + 1], filter[i + 1]]) 74 | ) 75 | self.decoder_att[j].append( 76 | self.att_layer([filter[i + 1] + filter[i], filter[i], filter[i]]) 77 | ) 78 | 79 | for i in range(4): 80 | if i < 3: 81 | self.encoder_block_att.append( 82 | self.conv_layer([filter[i + 1], filter[i + 2]]) 83 | ) 84 | self.decoder_block_att.append( 85 | self.conv_layer([filter[i + 1], filter[i]]) 86 | ) 87 | else: 88 | self.encoder_block_att.append( 89 | self.conv_layer([filter[i + 1], filter[i + 1]]) 90 | ) 91 | self.decoder_block_att.append( 92 | self.conv_layer([filter[i + 1], filter[i + 1]]) 93 | ) 94 | 95 | self.pred_task1 = self.conv_layer([filter[0], self.class_nb], pred=True) 96 | self.pred_task2 = self.conv_layer([filter[0], 1], pred=True) 97 | self.pred_task3 = self.conv_layer([filter[0], 3], pred=True) 98 | 99 | # define pooling and unpooling functions 100 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 101 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | nn.init.xavier_normal_(m.weight) 106 | nn.init.constant_(m.bias, 0) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | nn.init.constant_(m.weight, 1) 109 | nn.init.constant_(m.bias, 0) 110 | elif isinstance(m, nn.Linear): 111 | nn.init.xavier_normal_(m.weight) 112 | nn.init.constant_(m.bias, 0) 113 | 114 | def shared_modules(self): 115 | return [ 116 | self.encoder_block, 117 | self.decoder_block, 118 | self.conv_block_enc, 119 | self.conv_block_dec, 120 | # self.encoder_att, self.decoder_att, 121 | self.encoder_block_att, 122 | self.decoder_block_att, 123 | self.down_sampling, 124 | self.up_sampling, 125 | ] 126 | 127 | def zero_grad_shared_modules(self): 128 | for mm in self.shared_modules(): 129 | mm.zero_grad() 130 | 131 | def conv_layer(self, channel, pred=False): 132 | if not pred: 133 | conv_block = nn.Sequential( 134 | nn.Conv2d( 135 | in_channels=channel[0], 136 | out_channels=channel[1], 137 | kernel_size=3, 138 | padding=1, 139 | ), 140 | nn.BatchNorm2d(num_features=channel[1]), 141 | nn.ReLU(inplace=True), 142 | ) 143 | else: 144 | conv_block = nn.Sequential( 145 | nn.Conv2d( 146 | in_channels=channel[0], 147 | out_channels=channel[0], 148 | kernel_size=3, 149 | padding=1, 150 | ), 151 | nn.Conv2d( 152 | in_channels=channel[0], 153 | out_channels=channel[1], 154 | kernel_size=1, 155 | padding=0, 156 | ), 157 | ) 158 | return conv_block 159 | 160 | def att_layer(self, channel): 161 | att_block = nn.Sequential( 162 | nn.Conv2d( 163 | in_channels=channel[0], 164 | out_channels=channel[1], 165 | kernel_size=1, 166 | padding=0, 167 | ), 168 | nn.BatchNorm2d(channel[1]), 169 | nn.ReLU(inplace=True), 170 | nn.Conv2d( 171 | in_channels=channel[1], 172 | out_channels=channel[2], 173 | kernel_size=1, 174 | padding=0, 175 | ), 176 | nn.BatchNorm2d(channel[2]), 177 | nn.Sigmoid(), 178 | ) 179 | return att_block 180 | 181 | def forward(self, x): 182 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ( 183 | [0] * 5 for _ in range(5) 184 | ) 185 | for i in range(5): 186 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2)) 187 | 188 | # define attention list for tasks 189 | atten_encoder, atten_decoder = ([0] * 3 for _ in range(2)) 190 | for i in range(3): 191 | atten_encoder[i], atten_decoder[i] = ([0] * 5 for _ in range(2)) 192 | for i in range(3): 193 | for j in range(5): 194 | atten_encoder[i][j], atten_decoder[i][j] = ([0] * 3 for _ in range(2)) 195 | 196 | # define global shared network 197 | for i in range(5): 198 | if i == 0: 199 | g_encoder[i][0] = self.encoder_block[i](x) 200 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 201 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 202 | else: 203 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1]) 204 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 205 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 206 | 207 | for i in range(5): 208 | if i == 0: 209 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1]) 210 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 211 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 212 | else: 213 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1]) 214 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 215 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 216 | 217 | # define task dependent attention module 218 | for i in range(3): 219 | for j in range(5): 220 | if j == 0: 221 | atten_encoder[i][j][0] = self.encoder_att[i][j](g_encoder[j][0]) 222 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1] 223 | atten_encoder[i][j][2] = self.encoder_block_att[j]( 224 | atten_encoder[i][j][1] 225 | ) 226 | atten_encoder[i][j][2] = F.max_pool2d( 227 | atten_encoder[i][j][2], kernel_size=2, stride=2 228 | ) 229 | else: 230 | atten_encoder[i][j][0] = self.encoder_att[i][j]( 231 | torch.cat((g_encoder[j][0], atten_encoder[i][j - 1][2]), dim=1) 232 | ) 233 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1] 234 | atten_encoder[i][j][2] = self.encoder_block_att[j]( 235 | atten_encoder[i][j][1] 236 | ) 237 | atten_encoder[i][j][2] = F.max_pool2d( 238 | atten_encoder[i][j][2], kernel_size=2, stride=2 239 | ) 240 | 241 | for j in range(5): 242 | if j == 0: 243 | atten_decoder[i][j][0] = F.interpolate( 244 | atten_encoder[i][-1][-1], 245 | scale_factor=2, 246 | mode="bilinear", 247 | align_corners=True, 248 | ) 249 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1]( 250 | atten_decoder[i][j][0] 251 | ) 252 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1]( 253 | torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1) 254 | ) 255 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1] 256 | else: 257 | atten_decoder[i][j][0] = F.interpolate( 258 | atten_decoder[i][j - 1][2], 259 | scale_factor=2, 260 | mode="bilinear", 261 | align_corners=True, 262 | ) 263 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1]( 264 | atten_decoder[i][j][0] 265 | ) 266 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1]( 267 | torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1) 268 | ) 269 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1] 270 | 271 | # define task prediction layers 272 | t1_pred = F.log_softmax(self.pred_task1(atten_decoder[0][-1][-1]), dim=1) 273 | t2_pred = self.pred_task2(atten_decoder[1][-1][-1]) 274 | t3_pred = self.pred_task3(atten_decoder[2][-1][-1]) 275 | t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True) 276 | 277 | return ( 278 | [t1_pred, t2_pred, t3_pred], 279 | ( 280 | atten_decoder[0][-1][-1], 281 | atten_decoder[1][-1][-1], 282 | atten_decoder[2][-1][-1], 283 | ), 284 | ) 285 | 286 | 287 | class SegNetMtan(nn.Module): 288 | def __init__(self): 289 | super().__init__() 290 | self.segnet = _SegNet() 291 | 292 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 293 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n) 294 | 295 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]: 296 | return (p for n, p in self.segnet.named_parameters() if "pred" in n) 297 | 298 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 299 | """Parameters of the last shared layer. 300 | Returns 301 | ------- 302 | """ 303 | return [] 304 | 305 | def forward(self, x, return_representation=False): 306 | if return_representation: 307 | return self.segnet(x) 308 | else: 309 | pred, rep = self.segnet(x) 310 | return pred 311 | 312 | 313 | class SegNetSplit(nn.Module): 314 | def __init__(self, model_type="standard"): 315 | super(SegNetSplit, self).__init__() 316 | # initialise network parameters 317 | assert model_type in ["standard", "wide", "deep"] 318 | self.model_type = model_type 319 | if self.model_type == "wide": 320 | filter = [64, 128, 256, 512, 1024] 321 | else: 322 | filter = [64, 128, 256, 512, 512] 323 | 324 | self.class_nb = 13 325 | 326 | # define encoder decoder layers 327 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])]) 328 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 329 | for i in range(4): 330 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]])) 331 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]])) 332 | 333 | # define convolution layer 334 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 335 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 336 | for i in range(4): 337 | if i == 0: 338 | self.conv_block_enc.append( 339 | self.conv_layer([filter[i + 1], filter[i + 1]]) 340 | ) 341 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]])) 342 | else: 343 | self.conv_block_enc.append( 344 | nn.Sequential( 345 | self.conv_layer([filter[i + 1], filter[i + 1]]), 346 | self.conv_layer([filter[i + 1], filter[i + 1]]), 347 | ) 348 | ) 349 | self.conv_block_dec.append( 350 | nn.Sequential( 351 | self.conv_layer([filter[i], filter[i]]), 352 | self.conv_layer([filter[i], filter[i]]), 353 | ) 354 | ) 355 | 356 | # define task specific layers 357 | self.pred_task1 = nn.Sequential( 358 | nn.Conv2d( 359 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1 360 | ), 361 | nn.Conv2d( 362 | in_channels=filter[0], 363 | out_channels=self.class_nb, 364 | kernel_size=1, 365 | padding=0, 366 | ), 367 | ) 368 | self.pred_task2 = nn.Sequential( 369 | nn.Conv2d( 370 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1 371 | ), 372 | nn.Conv2d(in_channels=filter[0], out_channels=1, kernel_size=1, padding=0), 373 | ) 374 | self.pred_task3 = nn.Sequential( 375 | nn.Conv2d( 376 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1 377 | ), 378 | nn.Conv2d(in_channels=filter[0], out_channels=3, kernel_size=1, padding=0), 379 | ) 380 | 381 | # define pooling and unpooling functions 382 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 383 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2) 384 | 385 | for m in self.modules(): 386 | if isinstance(m, nn.Conv2d): 387 | nn.init.xavier_normal_(m.weight) 388 | nn.init.constant_(m.bias, 0) 389 | elif isinstance(m, nn.BatchNorm2d): 390 | nn.init.constant_(m.weight, 1) 391 | nn.init.constant_(m.bias, 0) 392 | elif isinstance(m, nn.Linear): 393 | nn.init.xavier_normal_(m.weight) 394 | nn.init.constant_(m.bias, 0) 395 | 396 | # define convolutional block 397 | def conv_layer(self, channel): 398 | if self.model_type == "deep": 399 | conv_block = nn.Sequential( 400 | nn.Conv2d( 401 | in_channels=channel[0], 402 | out_channels=channel[1], 403 | kernel_size=3, 404 | padding=1, 405 | ), 406 | nn.BatchNorm2d(num_features=channel[1]), 407 | nn.ReLU(inplace=True), 408 | nn.Conv2d( 409 | in_channels=channel[1], 410 | out_channels=channel[1], 411 | kernel_size=3, 412 | padding=1, 413 | ), 414 | nn.BatchNorm2d(num_features=channel[1]), 415 | nn.ReLU(inplace=True), 416 | ) 417 | else: 418 | conv_block = nn.Sequential( 419 | nn.Conv2d( 420 | in_channels=channel[0], 421 | out_channels=channel[1], 422 | kernel_size=3, 423 | padding=1, 424 | ), 425 | nn.BatchNorm2d(num_features=channel[1]), 426 | nn.ReLU(inplace=True), 427 | ) 428 | return conv_block 429 | 430 | def forward(self, x): 431 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ( 432 | [0] * 5 for _ in range(5) 433 | ) 434 | for i in range(5): 435 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2)) 436 | 437 | # global shared encoder-decoder network 438 | for i in range(5): 439 | if i == 0: 440 | g_encoder[i][0] = self.encoder_block[i](x) 441 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 442 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 443 | else: 444 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1]) 445 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 446 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 447 | 448 | for i in range(5): 449 | if i == 0: 450 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1]) 451 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 452 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 453 | else: 454 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1]) 455 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 456 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 457 | 458 | # define task prediction layers 459 | t1_pred = F.log_softmax(self.pred_task1(g_decoder[i][1]), dim=1) 460 | t2_pred = self.pred_task2(g_decoder[i][1]) 461 | t3_pred = self.pred_task3(g_decoder[i][1]) 462 | t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True) 463 | 464 | return [t1_pred, t2_pred, t3_pred], g_decoder[i][ 465 | 1 466 | ] # NOTE: last element is representation 467 | 468 | 469 | class SegNet(nn.Module): 470 | def __init__(self): 471 | super().__init__() 472 | self.segnet = SegNetSplit() 473 | 474 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 475 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n) 476 | 477 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]: 478 | return (p for n, p in self.segnet.named_parameters() if "pred" in n) 479 | 480 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 481 | """Parameters of the last shared layer. 482 | Returns 483 | ------- 484 | """ 485 | return self.segnet.conv_block_dec[-5].parameters() 486 | 487 | def forward(self, x, return_representation=False): 488 | if return_representation: 489 | return self.segnet(x) 490 | else: 491 | pred, rep = self.segnet(x) 492 | return pred 493 | -------------------------------------------------------------------------------- /methods/weight_methods.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from abc import abstractmethod 4 | from typing import Dict, List, Tuple, Union 5 | 6 | import cvxpy as cp 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from scipy.optimize import minimize 11 | 12 | from methods.min_norm_solvers import MinNormSolver, gradient_normalizers 13 | 14 | 15 | class WeightMethod: 16 | def __init__(self, n_tasks: int, device: torch.device): 17 | super().__init__() 18 | self.n_tasks = n_tasks 19 | self.device = device 20 | 21 | @abstractmethod 22 | def get_weighted_loss( 23 | self, 24 | losses: torch.Tensor, 25 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor], 26 | task_specific_parameters: Union[ 27 | List[torch.nn.parameter.Parameter], torch.Tensor 28 | ], 29 | last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor], 30 | representation: Union[torch.nn.parameter.Parameter, torch.Tensor], 31 | **kwargs, 32 | ): 33 | pass 34 | 35 | def backward( 36 | self, 37 | losses: torch.Tensor, 38 | shared_parameters: Union[ 39 | List[torch.nn.parameter.Parameter], torch.Tensor 40 | ] = None, 41 | task_specific_parameters: Union[ 42 | List[torch.nn.parameter.Parameter], torch.Tensor 43 | ] = None, 44 | last_shared_parameters: Union[ 45 | List[torch.nn.parameter.Parameter], torch.Tensor 46 | ] = None, 47 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 48 | **kwargs, 49 | ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]: 50 | """ 51 | 52 | Parameters 53 | ---------- 54 | losses : 55 | shared_parameters : 56 | task_specific_parameters : 57 | last_shared_parameters : parameters of last shared layer/block 58 | representation : shared representation 59 | kwargs : 60 | 61 | Returns 62 | ------- 63 | Loss, extra outputs 64 | """ 65 | loss, extra_outputs = self.get_weighted_loss( 66 | losses=losses, 67 | shared_parameters=shared_parameters, 68 | task_specific_parameters=task_specific_parameters, 69 | last_shared_parameters=last_shared_parameters, 70 | representation=representation, 71 | **kwargs, 72 | ) 73 | loss.backward() 74 | return loss, extra_outputs 75 | 76 | def __call__( 77 | self, 78 | losses: torch.Tensor, 79 | shared_parameters: Union[ 80 | List[torch.nn.parameter.Parameter], torch.Tensor 81 | ] = None, 82 | task_specific_parameters: Union[ 83 | List[torch.nn.parameter.Parameter], torch.Tensor 84 | ] = None, 85 | **kwargs, 86 | ): 87 | return self.backward( 88 | losses=losses, 89 | shared_parameters=shared_parameters, 90 | task_specific_parameters=task_specific_parameters, 91 | **kwargs, 92 | ) 93 | 94 | def parameters(self) -> List[torch.Tensor]: 95 | """return learnable parameters""" 96 | return [] 97 | 98 | 99 | class NashMTL(WeightMethod): 100 | def __init__( 101 | self, 102 | n_tasks: int, 103 | device: torch.device, 104 | max_norm: float = 1.0, 105 | update_weights_every: int = 1, 106 | optim_niter=20, 107 | ): 108 | super(NashMTL, self).__init__( 109 | n_tasks=n_tasks, 110 | device=device, 111 | ) 112 | 113 | self.optim_niter = optim_niter 114 | self.update_weights_every = update_weights_every 115 | self.max_norm = max_norm 116 | 117 | self.prvs_alpha_param = None 118 | self.normalization_factor = np.ones((1,)) 119 | self.init_gtg = self.init_gtg = np.eye(self.n_tasks) 120 | self.step = 0.0 121 | self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32) 122 | 123 | def _stop_criteria(self, gtg, alpha_t): 124 | return ( 125 | (self.alpha_param.value is None) 126 | or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3) 127 | or ( 128 | np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) 129 | < 1e-6 130 | ) 131 | ) 132 | 133 | def solve_optimization(self, gtg: np.array): 134 | self.G_param.value = gtg 135 | self.normalization_factor_param.value = self.normalization_factor 136 | 137 | alpha_t = self.prvs_alpha 138 | for _ in range(self.optim_niter): 139 | self.alpha_param.value = alpha_t 140 | self.prvs_alpha_param.value = alpha_t 141 | 142 | try: 143 | self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100) 144 | except: 145 | self.alpha_param.value = self.prvs_alpha_param.value 146 | 147 | if self._stop_criteria(gtg, alpha_t): 148 | break 149 | 150 | alpha_t = self.alpha_param.value 151 | 152 | if alpha_t is not None: 153 | self.prvs_alpha = alpha_t 154 | 155 | return self.prvs_alpha 156 | 157 | def _calc_phi_alpha_linearization(self): 158 | G_prvs_alpha = self.G_param @ self.prvs_alpha_param 159 | prvs_phi_tag = 1 / self.prvs_alpha_param + (1 / G_prvs_alpha) @ self.G_param 160 | phi_alpha = prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param) 161 | return phi_alpha 162 | 163 | def _init_optim_problem(self): 164 | self.alpha_param = cp.Variable(shape=(self.n_tasks,), nonneg=True) 165 | self.prvs_alpha_param = cp.Parameter( 166 | shape=(self.n_tasks,), value=self.prvs_alpha 167 | ) 168 | self.G_param = cp.Parameter( 169 | shape=(self.n_tasks, self.n_tasks), value=self.init_gtg 170 | ) 171 | self.normalization_factor_param = cp.Parameter( 172 | shape=(1,), value=np.array([1.0]) 173 | ) 174 | 175 | self.phi_alpha = self._calc_phi_alpha_linearization() 176 | 177 | G_alpha = self.G_param @ self.alpha_param 178 | constraint = [] 179 | for i in range(self.n_tasks): 180 | constraint.append( 181 | -cp.log(self.alpha_param[i] * self.normalization_factor_param) 182 | - cp.log(G_alpha[i]) 183 | <= 0 184 | ) 185 | obj = cp.Minimize( 186 | cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param 187 | ) 188 | self.prob = cp.Problem(obj, constraint) 189 | 190 | def get_weighted_loss( 191 | self, 192 | losses, 193 | shared_parameters, 194 | **kwargs, 195 | ): 196 | """ 197 | 198 | Parameters 199 | ---------- 200 | losses : 201 | shared_parameters : shared parameters 202 | kwargs : 203 | 204 | Returns 205 | ------- 206 | 207 | """ 208 | 209 | extra_outputs = dict() 210 | if self.step == 0: 211 | self._init_optim_problem() 212 | 213 | if (self.step % self.update_weights_every) == 0: 214 | self.step += 1 215 | 216 | grads = {} 217 | for i, loss in enumerate(losses): 218 | g = list( 219 | torch.autograd.grad( 220 | loss, 221 | shared_parameters, 222 | retain_graph=True, 223 | ) 224 | ) 225 | grad = torch.cat([torch.flatten(grad) for grad in g]) 226 | grads[i] = grad 227 | 228 | G = torch.stack(tuple(v for v in grads.values())) 229 | GTG = torch.mm(G, G.t()) 230 | 231 | self.normalization_factor = ( 232 | torch.norm(GTG).detach().cpu().numpy().reshape((1,)) 233 | ) 234 | GTG = GTG / self.normalization_factor.item() 235 | alpha = self.solve_optimization(GTG.cpu().detach().numpy()) 236 | alpha = torch.from_numpy(alpha) 237 | 238 | else: 239 | self.step += 1 240 | alpha = self.prvs_alpha 241 | 242 | weighted_loss = sum([losses[i] * alpha[i] for i in range(len(alpha))]) 243 | extra_outputs["weights"] = alpha 244 | return weighted_loss, extra_outputs 245 | 246 | def backward( 247 | self, 248 | losses: torch.Tensor, 249 | shared_parameters: Union[ 250 | List[torch.nn.parameter.Parameter], torch.Tensor 251 | ] = None, 252 | task_specific_parameters: Union[ 253 | List[torch.nn.parameter.Parameter], torch.Tensor 254 | ] = None, 255 | last_shared_parameters: Union[ 256 | List[torch.nn.parameter.Parameter], torch.Tensor 257 | ] = None, 258 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 259 | **kwargs, 260 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]: 261 | loss, extra_outputs = self.get_weighted_loss( 262 | losses=losses, 263 | shared_parameters=shared_parameters, 264 | **kwargs, 265 | ) 266 | loss.backward() 267 | 268 | # make sure the solution for shared params has norm <= self.eps 269 | if self.max_norm > 0: 270 | torch.nn.utils.clip_grad_norm_(shared_parameters, self.max_norm) 271 | 272 | return loss, extra_outputs 273 | 274 | 275 | class LinearScalarization(WeightMethod): 276 | """Linear scalarization baseline L = sum_j w_j * l_j where l_j is the loss for task j and w_h""" 277 | 278 | def __init__( 279 | self, 280 | n_tasks: int, 281 | device: torch.device, 282 | task_weights: Union[List[float], torch.Tensor] = None, 283 | ): 284 | super().__init__(n_tasks, device=device) 285 | if task_weights is None: 286 | task_weights = torch.ones((n_tasks,)) 287 | if not isinstance(task_weights, torch.Tensor): 288 | task_weights = torch.tensor(task_weights) 289 | assert len(task_weights) == n_tasks 290 | self.task_weights = task_weights.to(device) 291 | 292 | def get_weighted_loss(self, losses, **kwargs): 293 | loss = torch.sum(losses * self.task_weights) 294 | return loss, dict(weights=self.task_weights) 295 | 296 | 297 | class ScaleInvariantLinearScalarization(WeightMethod): 298 | """Linear scalarization baseline L = sum_j w_j * l_j where l_j is the loss for task j and w_h""" 299 | 300 | def __init__( 301 | self, 302 | n_tasks: int, 303 | device: torch.device, 304 | task_weights: Union[List[float], torch.Tensor] = None, 305 | ): 306 | super().__init__(n_tasks, device=device) 307 | if task_weights is None: 308 | task_weights = torch.ones((n_tasks,)) 309 | if not isinstance(task_weights, torch.Tensor): 310 | task_weights = torch.tensor(task_weights) 311 | assert len(task_weights) == n_tasks 312 | self.task_weights = task_weights.to(device) 313 | 314 | def get_weighted_loss(self, losses, **kwargs): 315 | loss = torch.sum(torch.log(losses) * self.task_weights) 316 | return loss, dict(weights=self.task_weights) 317 | 318 | 319 | class MGDA(WeightMethod): 320 | """Based on the official implementation of: Multi-Task Learning as Multi-Objective Optimization 321 | Ozan Sener, Vladlen Koltun 322 | Neural Information Processing Systems (NeurIPS) 2018 323 | https://github.com/intel-isl/MultiObjectiveOptimization 324 | 325 | """ 326 | 327 | def __init__( 328 | self, n_tasks, device: torch.device, params="shared", normalization="none" 329 | ): 330 | super().__init__(n_tasks, device=device) 331 | self.solver = MinNormSolver() 332 | assert params in ["shared", "last", "rep"] 333 | self.params = params 334 | assert normalization in ["norm", "loss", "loss+", "none"] 335 | self.normalization = normalization 336 | 337 | @staticmethod 338 | def _flattening(grad): 339 | return torch.cat( 340 | tuple( 341 | g.reshape( 342 | -1, 343 | ) 344 | for i, g in enumerate(grad) 345 | ), 346 | dim=0, 347 | ) 348 | 349 | def get_weighted_loss( 350 | self, 351 | losses, 352 | shared_parameters=None, 353 | last_shared_parameters=None, 354 | representation=None, 355 | **kwargs, 356 | ): 357 | """ 358 | 359 | Parameters 360 | ---------- 361 | losses : 362 | shared_parameters : 363 | last_shared_parameters : 364 | representation : 365 | kwargs : 366 | 367 | Returns 368 | ------- 369 | 370 | """ 371 | # Our code 372 | grads = {} 373 | params = dict( 374 | rep=representation, shared=shared_parameters, last=last_shared_parameters 375 | )[self.params] 376 | for i, loss in enumerate(losses): 377 | g = list( 378 | torch.autograd.grad( 379 | loss, 380 | params, 381 | retain_graph=True, 382 | ) 383 | ) 384 | # Normalize all gradients, this is optional and not included in the paper. 385 | 386 | grads[i] = [torch.flatten(grad) for grad in g] 387 | 388 | gn = gradient_normalizers(grads, losses, self.normalization) 389 | for t in range(self.n_tasks): 390 | for gr_i in range(len(grads[t])): 391 | grads[t][gr_i] = grads[t][gr_i] / gn[t] 392 | 393 | sol, min_norm = self.solver.find_min_norm_element( 394 | [grads[t] for t in range(len(grads))] 395 | ) 396 | sol = sol * self.n_tasks # make sure it sums to self.n_tasks 397 | weighted_loss = sum([losses[i] * sol[i] for i in range(len(sol))]) 398 | 399 | return weighted_loss, dict(weights=torch.from_numpy(sol.astype(np.float32))) 400 | 401 | 402 | class STL(WeightMethod): 403 | """Single task learning""" 404 | 405 | def __init__(self, n_tasks, device: torch.device, main_task): 406 | super().__init__(n_tasks, device=device) 407 | self.main_task = main_task 408 | self.weights = torch.zeros(n_tasks, device=device) 409 | self.weights[main_task] = 1.0 410 | 411 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 412 | assert len(losses) == self.n_tasks 413 | loss = losses[self.main_task] 414 | 415 | return loss, dict(weights=self.weights) 416 | 417 | 418 | class Uncertainty(WeightMethod): 419 | """Implementation of `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics` 420 | Source: https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynb 421 | """ 422 | 423 | def __init__(self, n_tasks, device: torch.device): 424 | super().__init__(n_tasks, device=device) 425 | self.logsigma = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True) 426 | 427 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 428 | loss = sum( 429 | [ 430 | 0.5 * (torch.exp(-logs) * loss + logs) 431 | for loss, logs in zip(losses, self.logsigma) 432 | ] 433 | ) 434 | 435 | return loss, dict( 436 | weights=torch.exp(-self.logsigma) 437 | ) # NOTE: not exactly task weights 438 | 439 | def parameters(self) -> List[torch.Tensor]: 440 | return [self.logsigma] 441 | 442 | 443 | class PCGrad(WeightMethod): 444 | """Modification of: https://github.com/WeiChengTseng/Pytorch-PCGrad/blob/master/pcgrad.py 445 | 446 | @misc{Pytorch-PCGrad, 447 | author = {Wei-Cheng Tseng}, 448 | title = {WeiChengTseng/Pytorch-PCGrad}, 449 | url = {https://github.com/WeiChengTseng/Pytorch-PCGrad.git}, 450 | year = {2020} 451 | } 452 | 453 | """ 454 | 455 | def __init__(self, n_tasks: int, device: torch.device, reduction="sum"): 456 | super().__init__(n_tasks, device=device) 457 | assert reduction in ["mean", "sum"] 458 | self.reduction = reduction 459 | 460 | def get_weighted_loss( 461 | self, 462 | losses: torch.Tensor, 463 | shared_parameters: Union[ 464 | List[torch.nn.parameter.Parameter], torch.Tensor 465 | ] = None, 466 | task_specific_parameters: Union[ 467 | List[torch.nn.parameter.Parameter], torch.Tensor 468 | ] = None, 469 | **kwargs, 470 | ): 471 | raise NotImplementedError 472 | 473 | def _set_pc_grads(self, losses, shared_parameters, task_specific_parameters=None): 474 | # shared part 475 | shared_grads = [] 476 | for l in losses: 477 | shared_grads.append( 478 | torch.autograd.grad(l, shared_parameters, retain_graph=True) 479 | ) 480 | 481 | if isinstance(shared_parameters, torch.Tensor): 482 | shared_parameters = [shared_parameters] 483 | non_conflict_shared_grads = self._project_conflicting(shared_grads) 484 | for p, g in zip(shared_parameters, non_conflict_shared_grads): 485 | p.grad = g 486 | 487 | # task specific part 488 | if task_specific_parameters is not None: 489 | task_specific_grads = torch.autograd.grad( 490 | losses.sum(), task_specific_parameters 491 | ) 492 | if isinstance(task_specific_parameters, torch.Tensor): 493 | task_specific_parameters = [task_specific_parameters] 494 | for p, g in zip(task_specific_parameters, task_specific_grads): 495 | p.grad = g 496 | 497 | def _project_conflicting(self, grads: List[Tuple[torch.Tensor]]): 498 | pc_grad = copy.deepcopy(grads) 499 | for g_i in pc_grad: 500 | random.shuffle(grads) 501 | for g_j in grads: 502 | g_i_g_j = sum( 503 | [ 504 | torch.dot(torch.flatten(grad_i), torch.flatten(grad_j)) 505 | for grad_i, grad_j in zip(g_i, g_j) 506 | ] 507 | ) 508 | if g_i_g_j < 0: 509 | g_j_norm_square = ( 510 | torch.norm(torch.cat([torch.flatten(g) for g in g_j])) ** 2 511 | ) 512 | for grad_i, grad_j in zip(g_i, g_j): 513 | grad_i -= g_i_g_j * grad_j / g_j_norm_square 514 | 515 | merged_grad = [sum(g) for g in zip(*pc_grad)] 516 | if self.reduction == "mean": 517 | merged_grad = [g / self.n_tasks for g in merged_grad] 518 | 519 | return merged_grad 520 | 521 | def backward( 522 | self, 523 | losses: torch.Tensor, 524 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 525 | shared_parameters: Union[ 526 | List[torch.nn.parameter.Parameter], torch.Tensor 527 | ] = None, 528 | task_specific_parameters: Union[ 529 | List[torch.nn.parameter.Parameter], torch.Tensor 530 | ] = None, 531 | **kwargs, 532 | ): 533 | self._set_pc_grads(losses, shared_parameters, task_specific_parameters) 534 | return None, {} # NOTE: to align with all other weight methods 535 | 536 | 537 | class CAGrad(WeightMethod): 538 | def __init__(self, n_tasks, device: torch.device, c=0.4): 539 | super().__init__(n_tasks, device=device) 540 | self.c = c 541 | 542 | def get_weighted_loss( 543 | self, 544 | losses, 545 | shared_parameters, 546 | **kwargs, 547 | ): 548 | """ 549 | Parameters 550 | ---------- 551 | losses : 552 | shared_parameters : shared parameters 553 | kwargs : 554 | Returns 555 | ------- 556 | """ 557 | # NOTE: we allow only shared params for now. Need to see paper for other options. 558 | grad_dims = [] 559 | for param in shared_parameters: 560 | grad_dims.append(param.data.numel()) 561 | grads = torch.Tensor(sum(grad_dims), self.n_tasks).to(self.device) 562 | 563 | for i in range(self.n_tasks): 564 | if i < (self.n_tasks - 1): 565 | losses[i].backward(retain_graph=True) 566 | else: 567 | losses[i].backward() 568 | self.grad2vec(shared_parameters, grads, grad_dims, i) 569 | # multi_task_model.zero_grad_shared_modules() 570 | for p in shared_parameters: 571 | p.grad = None 572 | 573 | g = self.cagrad(grads, alpha=self.c, rescale=1) 574 | self.overwrite_grad(shared_parameters, g, grad_dims) 575 | 576 | def cagrad(self, grads, alpha=0.5, rescale=1): 577 | GG = grads.t().mm(grads).cpu() # [num_tasks, num_tasks] 578 | g0_norm = (GG.mean() + 1e-8).sqrt() # norm of the average gradient 579 | 580 | x_start = np.ones(self.n_tasks) / self.n_tasks 581 | bnds = tuple((0, 1) for x in x_start) 582 | cons = {"type": "eq", "fun": lambda x: 1 - sum(x)} 583 | A = GG.numpy() 584 | b = x_start.copy() 585 | c = (alpha * g0_norm + 1e-8).item() 586 | 587 | def objfn(x): 588 | return ( 589 | x.reshape(1, self.n_tasks).dot(A).dot(b.reshape(self.n_tasks, 1)) 590 | + c 591 | * np.sqrt( 592 | x.reshape(1, self.n_tasks).dot(A).dot(x.reshape(self.n_tasks, 1)) 593 | + 1e-8 594 | ) 595 | ).sum() 596 | 597 | res = minimize(objfn, x_start, bounds=bnds, constraints=cons) 598 | w_cpu = res.x 599 | ww = torch.Tensor(w_cpu).to(grads.device) 600 | gw = (grads * ww.view(1, -1)).sum(1) 601 | gw_norm = gw.norm() 602 | lmbda = c / (gw_norm + 1e-8) 603 | g = grads.mean(1) + lmbda * gw 604 | if rescale == 0: 605 | return g 606 | elif rescale == 1: 607 | return g / (1 + alpha ** 2) 608 | else: 609 | return g / (1 + alpha) 610 | 611 | @staticmethod 612 | def grad2vec(shared_params, grads, grad_dims, task): 613 | # store the gradients 614 | grads[:, task].fill_(0.0) 615 | cnt = 0 616 | # for mm in m.shared_modules(): 617 | # for p in mm.parameters(): 618 | 619 | for param in shared_params: 620 | grad = param.grad 621 | if grad is not None: 622 | grad_cur = grad.data.detach().clone() 623 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 624 | en = sum(grad_dims[: cnt + 1]) 625 | grads[beg:en, task].copy_(grad_cur.data.view(-1)) 626 | cnt += 1 627 | 628 | def overwrite_grad(self, shared_parameters, newgrad, grad_dims): 629 | newgrad = newgrad * self.n_tasks # to match the sum loss 630 | cnt = 0 631 | 632 | # for mm in m.shared_modules(): 633 | # for param in mm.parameters(): 634 | for param in shared_parameters: 635 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 636 | en = sum(grad_dims[: cnt + 1]) 637 | this_grad = newgrad[beg:en].contiguous().view(param.data.size()) 638 | param.grad = this_grad.data.clone() 639 | cnt += 1 640 | 641 | def backward( 642 | self, 643 | losses: torch.Tensor, 644 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 645 | shared_parameters: Union[ 646 | List[torch.nn.parameter.Parameter], torch.Tensor 647 | ] = None, 648 | task_specific_parameters: Union[ 649 | List[torch.nn.parameter.Parameter], torch.Tensor 650 | ] = None, 651 | **kwargs, 652 | ): 653 | self.get_weighted_loss(losses, shared_parameters) 654 | return None, {} # NOTE: to align with all other weight methods 655 | 656 | 657 | class RLW(WeightMethod): 658 | """Random loss weighting: https://arxiv.org/pdf/2111.10603.pdf""" 659 | 660 | def __init__(self, n_tasks, device: torch.device): 661 | super().__init__(n_tasks, device=device) 662 | 663 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 664 | assert len(losses) == self.n_tasks 665 | weight = (F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device) 666 | loss = torch.sum(losses * weight) 667 | 668 | return loss, dict(weights=weight) 669 | 670 | 671 | class IMTLG(WeightMethod): 672 | """TOWARDS IMPARTIAL MULTI-TASK LEARNING: https://openreview.net/pdf?id=IMPnRXEWpvr""" 673 | 674 | def __init__(self, n_tasks, device: torch.device): 675 | super().__init__(n_tasks, device=device) 676 | 677 | def get_weighted_loss( 678 | self, 679 | losses, 680 | shared_parameters, 681 | **kwargs, 682 | ): 683 | grads = {} 684 | norm_grads = {} 685 | 686 | for i, loss in enumerate(losses): 687 | g = list( 688 | torch.autograd.grad( 689 | loss, 690 | shared_parameters, 691 | retain_graph=True, 692 | ) 693 | ) 694 | grad = torch.cat([torch.flatten(grad) for grad in g]) 695 | norm_term = torch.norm(grad) 696 | 697 | grads[i] = grad 698 | norm_grads[i] = grad / norm_term 699 | 700 | G = torch.stack(tuple(v for v in grads.values())) 701 | D = ( 702 | G[ 703 | 0, 704 | ] 705 | - G[ 706 | 1:, 707 | ] 708 | ) 709 | 710 | U = torch.stack(tuple(v for v in norm_grads.values())) 711 | U = ( 712 | U[ 713 | 0, 714 | ] 715 | - U[ 716 | 1:, 717 | ] 718 | ) 719 | first_element = torch.matmul( 720 | G[ 721 | 0, 722 | ], 723 | U.t(), 724 | ) 725 | try: 726 | second_element = torch.inverse(torch.matmul(D, U.t())) 727 | except: 728 | # workaround for cases where matrix is singular 729 | second_element = torch.inverse( 730 | torch.eye(self.n_tasks - 1, device=self.device) * 1e-8 731 | + torch.matmul(D, U.t()) 732 | ) 733 | 734 | alpha_ = torch.matmul(first_element, second_element) 735 | alpha = torch.cat( 736 | (torch.tensor(1 - alpha_.sum(), device=self.device).unsqueeze(-1), alpha_) 737 | ) 738 | 739 | loss = torch.sum(losses * alpha) 740 | 741 | return loss, dict(weights=alpha) 742 | 743 | 744 | class DynamicWeightAverage(WeightMethod): 745 | """Dynamic Weight Average from `End-to-End Multi-Task Learning with Attention`. 746 | Modification of: https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_segnet_split.py#L242 747 | """ 748 | 749 | def __init__( 750 | self, n_tasks, device: torch.device, iteration_window: int = 25, temp=2.0 751 | ): 752 | """ 753 | 754 | Parameters 755 | ---------- 756 | n_tasks : 757 | iteration_window : 'iteration' loss is averaged over the last 'iteration_window' losses 758 | temp : 759 | """ 760 | super().__init__(n_tasks, device=device) 761 | self.iteration_window = iteration_window 762 | self.temp = temp 763 | self.running_iterations = 0 764 | self.costs = np.ones((iteration_window * 2, n_tasks), dtype=np.float32) 765 | self.weights = np.ones(n_tasks, dtype=np.float32) 766 | 767 | def get_weighted_loss(self, losses, **kwargs): 768 | 769 | cost = losses.detach().cpu().numpy() 770 | 771 | # update costs - fifo 772 | self.costs[:-1, :] = self.costs[1:, :] 773 | self.costs[-1, :] = cost 774 | 775 | if self.running_iterations > self.iteration_window: 776 | ws = self.costs[self.iteration_window :, :].mean(0) / self.costs[ 777 | : self.iteration_window, : 778 | ].mean(0) 779 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / ( 780 | np.exp(ws / self.temp) 781 | ).sum() 782 | 783 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to( 784 | losses.device 785 | ) 786 | loss = (task_weights * losses).mean() 787 | 788 | self.running_iterations += 1 789 | 790 | return loss, dict(weights=task_weights) 791 | 792 | 793 | class WeightMethods: 794 | def __init__(self, method: str, n_tasks: int, device: torch.device, **kwargs): 795 | """ 796 | :param method: 797 | """ 798 | assert method in list(METHODS.keys()), f"unknown method {method}." 799 | 800 | self.method = METHODS[method](n_tasks=n_tasks, device=device, **kwargs) 801 | 802 | def get_weighted_loss(self, losses, **kwargs): 803 | return self.method.get_weighted_loss(losses, **kwargs) 804 | 805 | def backward( 806 | self, losses, **kwargs 807 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]: 808 | return self.method.backward(losses, **kwargs) 809 | 810 | def __ceil__(self, losses, **kwargs): 811 | return self.backward(losses, **kwargs) 812 | 813 | def parameters(self): 814 | return self.method.parameters() 815 | 816 | 817 | METHODS = dict( 818 | stl=STL, 819 | ls=LinearScalarization, 820 | uw=Uncertainty, 821 | pcgrad=PCGrad, 822 | mgda=MGDA, 823 | cagrad=CAGrad, 824 | nashmtl=NashMTL, 825 | scaleinvls=ScaleInvariantLinearScalarization, 826 | rlw=RLW, 827 | imtl=IMTLG, 828 | dwa=DynamicWeightAverage, 829 | ) 830 | --------------------------------------------------------------------------------