├── experiments
│   ├── __init__.py
│   ├── toy
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── utils.py
│   │   ├── problem.py
│   │   └── trainer.py
│   ├── nyuv2
│   │   ├── __init__.py
│   │   ├── dataset
│   │   │   └── .gitinclude
│   │   ├── README.md
│   │   ├── utils.py
│   │   ├── data.py
│   │   ├── trainer.py
│   │   └── models.py
│   ├── quantum_chemistry
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── models.py
│   │   ├── utils.py
│   │   └── trainer.py
│   └── utils.py
├── .isort.cfg
├── misc
│   └── toy_pareto_2d.png
├── methods
│   ├── __init__.py
│   ├── min_norm_solvers.py
│   └── weight_methods.py
├── requirements.txt
├── LICENSE
├── setup.py
├── .gitignore
└── README.md
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/toy/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/nyuv2/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/nyuv2/dataset/.gitinclude:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/quantum_chemistry/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | profile = black
3 | multi_line_output = 3
4 |
--------------------------------------------------------------------------------
/misc/toy_pareto_2d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AvivNavon/nash-mtl/HEAD/misc/toy_pareto_2d.png
--------------------------------------------------------------------------------
/experiments/toy/README.md:
--------------------------------------------------------------------------------
1 | # Illustrative Example
2 |
3 | Modification of the code in [CAGrad](https://github.com/Cranial-XIX/CAGrad).
--------------------------------------------------------------------------------
/experiments/quantum_chemistry/README.md:
--------------------------------------------------------------------------------
1 | # QM9 Experiment
2 |
3 | Modification of the example code in [PyG](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_nn_conv.py).
--------------------------------------------------------------------------------
/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from methods.weight_methods import (
2 | METHODS,
3 | MGDA,
4 | STL,
5 | LinearScalarization,
6 | NashMTL,
7 | PCGrad,
8 | Uncertainty,
9 | )
10 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.2.1
2 | numpy>=1.18.2
3 | torch>=1.4.0
4 | torchvision>=0.8.0
5 | cvxpy==1.5.4
6 | tqdm>=4.45.0
7 | pandas
8 | scikit-learn
9 | seaborn
10 | wandb
11 | plotly
12 |
--------------------------------------------------------------------------------
/experiments/nyuv2/README.md:
--------------------------------------------------------------------------------
1 | # NYUv2 Experiment
2 |
3 | Modification of the code in [CAGrad](https://github.com/Cranial-XIX/CAGrad) and [MTAN](https://github.com/lorenmt/mtan).
4 |
5 | ## Dataset
6 |
7 | The dataset is available at [this link](https://www.dropbox.com/sh/86nssgwm6hm3vkb/AACrnUQ4GxpdrBbLjb6n-mWNa?dl=0). Put the downloaded files in `./dataset` so that the folder structure is `./dataset/train` and `./dataset/val`.
8 |
9 | ## Evaluation
10 |
11 | To align with previous work on MTL ([Liu et al. (2019)](https://arxiv.org/abs/1803.10704); [Yu et al. (2020)](https://arxiv.org/abs/2001.06782); [Liu et al. (2021)](https://arxiv.org/pdf/2110.14048.pdf)), we report the test performance averaged
12 | over the last 10 epochs. Note that this averaging is not handled in the code and needs to be applied by the user; see the sketch below.
13 |
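14 | For reference, a minimal sketch of that averaging, assuming you collect the
15 | per-epoch metrics in the `avg_cost` array built in `trainer.py` (columns 12
16 | onward hold the test metrics):
17 | 
18 | ```python
19 | import numpy as np
20 | 
21 | # avg_cost has shape (n_epochs, 24); average the test columns over the last 10 epochs
22 | last_10_test_avg = avg_cost[-10:, 12:].mean(axis=0)
23 | ```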
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Aviv Navon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from io import open
2 | from os import path
3 |
4 | from setuptools import find_packages, setup
5 |
6 | here = path.abspath(path.dirname(__file__))
7 |
8 | # get the long description from the README.md file
9 | with open(path.join(here, "README.md"), encoding="utf-8") as f:
10 | long_description = f.read()
11 |
12 |
13 | # get reqs
14 | def requirements():
15 | list_requirements = []
16 | with open("requirements.txt") as f:
17 | for line in f:
18 | list_requirements.append(line.rstrip())
19 | return list_requirements
20 |
21 |
22 | setup(
23 | name="nashmtl",
24 | version="1.0.0", # Required
25 | description="Nash-MTL", # Optional
26 | long_description="", # Optional
27 |     long_description_content_type="text/markdown",  # Optional
28 | url="", # Optional
29 | author="", # Optional
30 | author_email="", # Optional
31 | packages=find_packages(exclude=["contrib", "docs", "tests"]),
33 | python_requires=">=3.6",
34 | install_requires=requirements(), # Optional
35 | )
36 |
--------------------------------------------------------------------------------
/experiments/toy/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import numpy as np
3 | import seaborn as sns
4 | import torch
5 | from matplotlib import pyplot as plt
6 |
7 | from experiments.toy.problem import Toy
8 |
9 |
10 | # plotting utils
11 | def plot_2d_pareto(trajectories: dict, scale):
12 | """Adaptation of code from: https://github.com/Cranial-XIX/CAGrad"""
13 | fig, ax = plt.subplots(figsize=(6, 5))
14 |
15 | F = Toy(scale=scale)
16 |
17 | losses = []
18 | for res in trajectories.values():
19 | losses.append(F.batch_forward(torch.from_numpy(res["traj"])))
20 |
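21 |     # the Pareto front of this toy problem lies (approximately) along the
22 |     # horizontal line x2 = -8.3552, so evaluate both losses along it to draw it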
21 | yy = -8.3552
22 | x = np.linspace(-7, 7, 1000)
23 |
24 | inpt = np.stack((x, [yy] * len(x))).T
25 | Xs = torch.from_numpy(inpt).double()
26 |
27 | Ys = F.batch_forward(Xs)
28 | ax.plot(
29 | Ys.numpy()[:, 0],
30 | Ys.numpy()[:, 1],
31 | "-",
32 | linewidth=8,
33 | color="#72727A",
34 | label="Pareto Front",
35 | ) # Pareto front
36 |
37 | for i, tt in enumerate(losses):
38 | ax.scatter(
39 | tt[0, 0],
40 | tt[0, 1],
41 | color="k",
42 | s=150,
43 | zorder=10,
44 | label="Initial Point" if i == 0 else None,
45 | )
46 | colors = matplotlib.cm.magma_r(np.linspace(0.1, 0.6, tt.shape[0]))
47 | ax.scatter(tt[:, 0], tt[:, 1], color=colors, s=5, zorder=9)
48 |
49 | sns.despine()
50 | ax.set_xlabel(r"$\ell_1$", size=30)
51 | ax.set_ylabel(r"$\ell_2$", size=30)
52 | ax.xaxis.set_label_coords(1.015, -0.03)
53 | ax.yaxis.set_label_coords(-0.01, 1.01)
54 |
55 |     # set tick-label sizes; `tick.label` was removed in newer matplotlib,
56 |     # so use tick_params instead of iterating over the ticks
57 |     ax.tick_params(axis="both", which="major", labelsize=20)
59 |
60 | plt.tight_layout()
61 |
62 | legend = ax.legend(
63 | loc=2, bbox_to_anchor=(-0.15, 1.3), frameon=False, fontsize=20, ncol=2
64 | )
65 |
66 | return ax, fig, legend
67 |
--------------------------------------------------------------------------------
/experiments/toy/problem.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | LOWER = 0.000005
5 |
6 |
7 | class Toy(nn.Module):
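8 |     """Two-objective toy problem adapted from CAGrad: `forward` evaluates the two
9 |     losses at a single point x (optionally also returning their gradients), and
10 |     `batch_forward` evaluates both losses for a batch of points."""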
8 | def __init__(self, scale=1.0, scale_both_losses=1.0):
9 | super(Toy, self).__init__()
10 | self.centers = torch.Tensor([[-3.0, 0], [3.0, 0]])
11 | self.scale = scale
12 | self.scale_both_losses = scale_both_losses
13 |
14 | def forward(self, x, compute_grad=False):
15 | x1 = x[0]
16 | x2 = x[1]
17 |
18 | f1 = torch.clamp((0.5 * (-x1 - 7) - torch.tanh(-x2)).abs(), LOWER).log() + 6
19 | f2 = torch.clamp((0.5 * (-x1 + 3) + torch.tanh(-x2) + 2).abs(), LOWER).log() + 6
20 | c1 = torch.clamp(torch.tanh(x2 * 0.5), 0)
21 |
22 | f1_sq = ((-x1 + 7).pow(2) + 0.1 * (-x2 - 8).pow(2)) / 10 - 20
23 | f2_sq = ((-x1 - 7).pow(2) + 0.1 * (-x2 - 8).pow(2)) / 10 - 20
24 | c2 = torch.clamp(torch.tanh(-x2 * 0.5), 0)
25 |
26 | f1 = f1 * c1 + f1_sq * c2
27 | f1 *= self.scale
28 | f2 = f2 * c1 + f2_sq * c2
29 |
30 | f = torch.stack([f1, f2]) * self.scale_both_losses
31 | if compute_grad:
32 | g11 = torch.autograd.grad(f1, x1, retain_graph=True)[0].item()
33 | g12 = torch.autograd.grad(f1, x2, retain_graph=True)[0].item()
34 | g21 = torch.autograd.grad(f2, x1, retain_graph=True)[0].item()
35 | g22 = torch.autograd.grad(f2, x2, retain_graph=True)[0].item()
36 | g = torch.Tensor([[g11, g21], [g12, g22]])
37 | return f, g
38 | else:
39 | return f
40 |
41 | def batch_forward(self, x):
42 | x1 = x[:, 0]
43 | x2 = x[:, 1]
44 |
45 | f1 = torch.clamp((0.5 * (-x1 - 7) - torch.tanh(-x2)).abs(), LOWER).log() + 6
46 | f2 = torch.clamp((0.5 * (-x1 + 3) + torch.tanh(-x2) + 2).abs(), LOWER).log() + 6
47 | c1 = torch.clamp(torch.tanh(x2 * 0.5), 0)
48 |
49 | f1_sq = ((-x1 + 7).pow(2) + 0.1 * (-x2 - 8).pow(2)) / 10 - 20
50 | f2_sq = ((-x1 - 7).pow(2) + 0.1 * (-x2 - 8).pow(2)) / 10 - 20
51 | c2 = torch.clamp(torch.tanh(-x2 * 0.5), 0)
52 |
53 | f1 = f1 * c1 + f1_sq * c2
54 | f1 *= self.scale
55 | f2 = f2 * c1 + f2_sq * c2
56 |
57 | f = torch.cat([f1.view(-1, 1), f2.view(-1, 1)], -1) * self.scale_both_losses
58 | return f
59 |
--------------------------------------------------------------------------------
/experiments/quantum_chemistry/models.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import Iterator
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.nn import GRU, Linear, ReLU, Sequential
7 | from torch_geometric.nn import DimeNet, NNConv, Set2Set, radius_graph
8 |
9 |
10 | class Net(torch.nn.Module):
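11 |     """Message-passing network for QM9 (NNConv + GRU message passing with a
12 |     Set2Set readout, as in the PyG qm9_nn_conv example), with one linear head
13 |     per task."""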
11 | def __init__(self, n_tasks, num_features=11, dim=64):
12 | super().__init__()
13 | self.n_tasks = n_tasks
14 | self.dim = dim
15 | self.lin0 = torch.nn.Linear(num_features, dim)
16 |
17 | nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim))
18 | self.conv = NNConv(dim, dim, nn, aggr="mean")
19 | self.gru = GRU(dim, dim)
20 |
21 | self.set2set = Set2Set(dim, processing_steps=3)
22 | self.lin1 = torch.nn.Linear(2 * dim, dim)
23 |
24 | self._init_task_heads()
25 |
26 | def _init_task_heads(self):
27 | for i in range(self.n_tasks):
28 | setattr(self, f"head_{i}", torch.nn.Linear(self.dim, 1))
29 | self.task_specific = torch.nn.ModuleList(
30 | [getattr(self, f"head_{i}") for i in range(self.n_tasks)]
31 | )
32 |
33 | def forward(self, data, return_representation=False):
34 | out = F.relu(self.lin0(data.x))
35 | h = out.unsqueeze(0)
36 |
37 | for i in range(3):
38 | m = F.relu(self.conv(out, data.edge_index, data.edge_attr))
39 | out, h = self.gru(m.unsqueeze(0), h)
40 | out = out.squeeze(0)
41 |
42 | out = self.set2set(out, data.batch)
43 | features = F.relu(self.lin1(out))
44 | logits = torch.cat(
45 | [getattr(self, f"head_{i}")(features) for i in range(self.n_tasks)], dim=1
46 | )
47 | if return_representation:
48 | return logits, features
49 | return logits
50 |
51 | def shared_parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
52 | return chain(
53 | self.lin0.parameters(),
54 | self.conv.parameters(),
55 | self.gru.parameters(),
56 | self.set2set.parameters(),
57 | self.lin1.parameters(),
58 | )
59 |
60 | def task_specific_parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
61 | return self.task_specific.parameters()
62 |
63 | def last_shared_parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
64 | return self.lin1.parameters()
65 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | .DS_Store
131 | .idea/
132 | wandb/
133 | datasets/
134 | outputs/
135 | data/
--------------------------------------------------------------------------------
/experiments/quantum_chemistry/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch_geometric.utils import remove_self_loops
4 |
5 |
6 | class MyTransform(object):
7 | def __init__(self, target: list = None):
8 | if target is None:
9 | target = torch.tensor([0, 1, 2, 3, 5, 6, 12, 13, 14, 15, 11]) # removing 4
10 | else:
11 | target = torch.tensor(target)
12 | self.target = target
13 |
14 | def __call__(self, data):
15 | # Specify target.
16 | data.y = data.y[:, self.target]
17 | return data
18 |
19 |
20 | class Complete(object):
21 | def __call__(self, data):
22 | device = data.edge_index.device
23 |
24 | row = torch.arange(data.num_nodes, dtype=torch.long, device=device)
25 | col = torch.arange(data.num_nodes, dtype=torch.long, device=device)
26 |
27 | row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)
28 | col = col.repeat(data.num_nodes)
29 | edge_index = torch.stack([row, col], dim=0)
30 |
31 | edge_attr = None
32 | if data.edge_attr is not None:
33 | idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]
34 | size = list(data.edge_attr.size())
35 | size[0] = data.num_nodes * data.num_nodes
36 | edge_attr = data.edge_attr.new_zeros(size)
37 | edge_attr[idx] = data.edge_attr
38 |
39 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
40 | data.edge_attr = edge_attr
41 | data.edge_index = edge_index
42 |
43 | return data
44 |
45 |
46 | qm9_target_dict = {
47 | 0: "mu",
48 | 1: "alpha",
49 | 2: "homo",
50 | 3: "lumo",
51 | 5: "r2",
52 | 6: "zpve",
53 | 7: "U0",
54 | 8: "U",
55 | 9: "H",
56 | 10: "G",
57 | 11: "Cv",
58 | }
59 |
60 | # for \Delta_m calculations
61 | # -------------------------
62 | # DimeNet uses the atomization energy for targets U0, U, H, and G.
63 | target_idx = [0, 1, 2, 3, 5, 6, 12, 13, 14, 15, 11]
64 |
65 | # Report meV instead of eV.
66 | multiply_indx = [2, 3, 5, 6, 7, 8, 9]
67 |
68 | n_tasks = len(target_idx)
69 |
70 | # stl results
71 | BASE = np.array(
72 | [
73 | 0.0671,
74 | 0.1814,
75 | 60.576,
76 | 53.915,
77 | 0.5027,
78 | 4.539,
79 | 58.838,
80 | 64.244,
81 | 63.852,
82 | 66.223,
83 | 0.07212,
84 | ]
85 | )
86 |
87 | SIGN = np.array([0] * n_tasks)
88 | KK = np.ones(n_tasks) * -1
89 |
90 |
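91 | # delta_fn computes \Delta_m: the mean per-task relative degradation w.r.t. the
92 | # STL baselines in BASE, in percent. (-1) ** SIGN flips the sign for metrics
93 | # where higher is better; for QM9 every target is an MAE (lower is better), so
94 | # SIGN is all zeros.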
91 | def delta_fn(a):
92 | return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0 # *100 for percentage
93 |
--------------------------------------------------------------------------------
/experiments/nyuv2/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class ConfMatrix(object):
6 | def __init__(self, num_classes):
7 | self.num_classes = num_classes
8 | self.mat = None
9 |
10 | def update(self, pred, target):
11 | n = self.num_classes
12 | if self.mat is None:
13 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
14 | with torch.no_grad():
15 | k = (target >= 0) & (target < n)
16 | inds = n * target[k].to(torch.int64) + pred[k]
17 | self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
18 |
19 | def get_metrics(self):
20 | h = self.mat.float()
21 | acc = torch.diag(h).sum() / h.sum()
22 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
23 | return torch.mean(iu).cpu().numpy(), acc.cpu().numpy()
24 |
25 |
26 | def depth_error(x_pred, x_output):
27 | device = x_pred.device
28 | binary_mask = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).to(device)
29 | x_pred_true = x_pred.masked_select(binary_mask)
30 | x_output_true = x_output.masked_select(binary_mask)
31 | abs_err = torch.abs(x_pred_true - x_output_true)
32 | rel_err = torch.abs(x_pred_true - x_output_true) / x_output_true
33 | return (
34 | torch.sum(abs_err) / torch.nonzero(binary_mask, as_tuple=False).size(0)
35 | ).item(), (
36 | torch.sum(rel_err) / torch.nonzero(binary_mask, as_tuple=False).size(0)
37 | ).item()
38 |
39 |
40 | def normal_error(x_pred, x_output):
41 | binary_mask = torch.sum(x_output, dim=1) != 0
42 | error = (
43 | torch.acos(
44 | torch.clamp(
45 | torch.sum(x_pred * x_output, 1).masked_select(binary_mask), -1, 1
46 | )
47 | )
48 | .detach()
49 | .cpu()
50 | .numpy()
51 | )
52 | error = np.degrees(error)
53 | return (
54 | np.mean(error),
55 | np.median(error),
56 | np.mean(error < 11.25),
57 | np.mean(error < 22.5),
58 | np.mean(error < 30),
59 | )
60 |
61 |
62 | # for calculating \Delta_m
63 | delta_stats = [
64 | "mean iou",
65 | "pix acc",
66 | "abs err",
67 | "rel err",
68 | "mean",
69 | "median",
70 | "<11.25",
71 | "<22.5",
72 | "<30",
73 | ]
74 | BASE = np.array(
75 | [0.3830, 0.6376, 0.6754, 0.2780, 25.01, 19.21, 0.3014, 0.5720, 0.6915]
76 | ) # base results from CAGrad
77 | SIGN = np.array([1, 1, 0, 0, 0, 0, 1, 1, 1])
78 | KK = np.ones(9) * -1
79 |
80 |
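81 | # delta_fn computes \Delta_m: the mean per-task relative degradation w.r.t. the
82 | # single-task baselines in BASE, in percent. (-1) ** SIGN flips the sign for the
83 | # metrics in delta_stats where higher is better (mean iou, pix acc, and the
84 | # angle-threshold accuracies).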
81 | def delta_fn(a):
82 | return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0 # * 100 for percentage
83 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Nash-MTL
2 |
3 | Official implementation of _"Multi-Task Learning as a Bargaining Game"_.
4 |
5 | ![Toy example: 2D Pareto front](misc/toy_pareto_2d.png)
6 | 
9 | ## Setup environment
10 |
11 | ```bash
12 | conda create -n nashmtl python=3.9.7
13 | conda activate nashmtl
14 | conda install pytorch==1.9.0 torchvision==0.10.0 cudatoolkit=10.2 -c pytorch
15 | conda install pyg -c pyg -c conda-forge
16 | ```
17 |
18 | Install the repo:
19 |
20 | ```bash
21 | git clone https://github.com/AvivNavon/nash-mtl.git
22 | cd nash-mtl
23 | pip install -e .
24 | ```
25 |
26 | ## Run experiment
27 |
28 | To run experiments:
29 |
30 | ```bash
31 | cd experiments/<experiment name>
32 | python trainer.py --method=nashmtl
33 | ```
34 | Follow the instructions in each experiment's README file for more information (e.g., regarding datasets).
35 |
36 | Here `<experiment name>` is one of `[toy, quantum_chemistry, nyuv2]`. You can also replace `nashmtl` with one of the MTL methods listed below.
37 |
38 | We also support experiment tracking with **[Weights & Biases](https://wandb.ai/site)** with two additional parameters:
39 |
40 | ```bash
41 | python trainer.py --method=nashmtl --wandb_project=<project-name> --wandb_entity=<entity-name>
42 | ```
43 |
44 | ## MTL methods
45 |
46 | We support the following MTL methods with a unified API; a short usage sketch follows the table. To run an experiment with MTL method `X`, simply run:
47 | ```bash
48 | python trainer.py --method=X
49 | ```
50 |
51 | | Method (code name) | Paper (notes) |
52 | | :---: | :---: |
53 | | Nash-MTL (`nashmtl`) | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017v1.pdf) |
54 | | CAGrad (`cagrad`) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048.pdf) |
55 | | PCGrad (`pcgrad`) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/abs/2001.06782) |
56 | | IMTL-G (`imtl`) | [Towards Impartial Multi-task Learning](https://openreview.net/forum?id=IMPnRXEWpvr) |
57 | | MGDA (`mgda`) | [Multi-Task Learning as Multi-Objective Optimization](https://arxiv.org/abs/1810.04650) |
58 | | DWA (`dwa`) | [End-to-End Multi-Task Learning with Attention](https://arxiv.org/abs/1803.10704) |
59 | | Uncertainty weighting (`uw`) | [Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics](https://arxiv.org/pdf/1705.07115v3.pdf) |
60 | | Linear scalarization (`ls`) | - (equal weighting) |
61 | | Scale-invariant baseline (`scaleinvls`) | - (see Nash-MTL paper for details) |
62 | | Random Loss Weighting (`rlw`) | [A Closer Look at Loss Weighting in Multi-Task Learning](https://arxiv.org/pdf/2111.10603.pdf) |
63 |
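64 | The weight methods can also be used programmatically. Below is a minimal sketch
65 | of the unified API, mirroring the call pattern in `experiments/toy/trainer.py`
66 | (the Nash-MTL kwargs shown are the defaults from `experiments/utils.py`):
67 | 
68 | ```python
69 | import torch
70 | 
71 | from experiments.toy.problem import Toy
72 | from methods.weight_methods import WeightMethods
73 | 
74 | device = torch.device("cpu")
75 | F = Toy()
76 | x = torch.tensor([0.0, 0.0], requires_grad=True)  # the (shared) parameters
77 | 
78 | method = WeightMethods(
79 |     method="nashmtl", n_tasks=2, device=device, update_weights_every=1, optim_niter=20
80 | )
81 | optimizer = torch.optim.Adam(
82 |     [dict(params=[x], lr=1e-3), dict(params=method.parameters(), lr=0.025)]
83 | )
84 | 
85 | for _ in range(100):
86 |     optimizer.zero_grad()
87 |     losses = F(x)  # tensor holding one loss per task
88 |     method.backward(
89 |         losses=losses,
90 |         shared_parameters=(x,),
91 |         task_specific_parameters=None,
92 |         last_shared_parameters=None,
93 |         representation=None,
94 |     )
95 |     optimizer.step()
96 | ```
97 | 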
64 |
65 |
66 |
67 | ## Citation
68 |
69 | If you find `Nash-MTL` to be useful in your own research, please consider citing the following paper:
70 |
71 | ```bib
72 | @article{navon2022multi,
73 | title={Multi-Task Learning as a Bargaining Game},
74 | author={Navon, Aviv and Shamsian, Aviv and Achituve, Idan and Maron, Haggai and Kawaguchi, Kenji and Chechik, Gal and Fetaya, Ethan},
75 | journal={arXiv preprint arXiv:2202.01017},
76 | year={2022}
77 | }
78 | ```
79 |
--------------------------------------------------------------------------------
/experiments/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import random
4 | from collections import defaultdict
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from methods import METHODS
11 |
12 |
13 | def str_to_list(string):
14 | return [float(s) for s in string.split(",")]
15 |
16 |
17 | def str_or_float(value):
18 |     try:
19 |         return float(value)
20 |     except ValueError:
21 |         return value
22 |
23 |
24 | def str2bool(v):
25 | if isinstance(v, bool):
26 | return v
27 | if v.lower() in ("yes", "true", "t", "y", "1"):
28 | return True
29 | elif v.lower() in ("no", "false", "f", "n", "0"):
30 | return False
31 | else:
32 | raise argparse.ArgumentTypeError("Boolean value expected.")
33 |
34 |
35 | common_parser = argparse.ArgumentParser(add_help=False)
36 | common_parser.add_argument("--data-path", type=Path, help="path to data")
37 | common_parser.add_argument("--n-epochs", type=int, default=300)
38 | common_parser.add_argument("--batch-size", type=int, default=120, help="batch size")
39 | common_parser.add_argument(
40 | "--method", type=str, choices=list(METHODS.keys()), help="MTL weight method"
41 | )
42 | common_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
43 | common_parser.add_argument(
44 | "--method-params-lr",
45 | type=float,
46 | default=0.025,
47 |     help="lr for weight-method parameters (e.g., for uncertainty weighting); if None, set to args.lr",
48 | )
49 | common_parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
50 | common_parser.add_argument("--seed", type=int, default=42, help="seed value")
51 | # NashMTL
52 | common_parser.add_argument(
53 | "--nashmtl-optim-niter", type=int, default=20, help="number of CCCP iterations"
54 | )
55 | common_parser.add_argument(
56 | "--update-weights-every",
57 | type=int,
58 | default=1,
59 | help="update task weights every x iterations.",
60 | )
61 | # stl
62 | common_parser.add_argument(
63 | "--main-task",
64 | type=int,
65 | default=0,
66 | help="main task for stl. Ignored if method != stl",
67 | )
68 | # cagrad
69 | common_parser.add_argument("--c", type=float, default=0.4, help="c for CAGrad alg.")
70 | # dwa
72 | common_parser.add_argument(
73 | "--dwa-temp",
74 | type=float,
75 | default=2.0,
76 |     help="Temperature hyper-parameter for DWA. Defaults to 2, as in the original paper.",
77 | )
78 |
79 |
80 | def count_parameters(model):
81 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
82 |
83 |
84 | def set_logger():
85 | logging.basicConfig(
86 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
87 | level=logging.INFO,
88 | )
89 |
90 |
91 | def set_seed(seed):
92 |     """Set random seeds (numpy, random, torch) for reproducibility."""
96 | np.random.seed(seed)
97 | random.seed(seed)
98 |
99 | torch.manual_seed(seed)
100 | if torch.cuda.is_available():
101 | torch.cuda.manual_seed(seed)
102 | torch.cuda.manual_seed_all(seed)
103 |
104 | torch.backends.cudnn.enabled = True
105 | torch.backends.cudnn.benchmark = False
106 | torch.backends.cudnn.deterministic = True
107 |
108 |
109 | def get_device(no_cuda=False, gpus="0"):
110 | return torch.device(
111 | f"cuda:{gpus}" if torch.cuda.is_available() and not no_cuda else "cpu"
112 | )
113 |
114 |
115 | def extract_weight_method_parameters_from_args(args):
116 | weight_methods_parameters = defaultdict(dict)
117 | weight_methods_parameters.update(
118 | dict(
119 | nashmtl=dict(
120 | update_weights_every=args.update_weights_every,
121 | optim_niter=args.nashmtl_optim_niter,
122 | ),
123 | stl=dict(main_task=args.main_task),
124 | cagrad=dict(c=args.c),
125 | dwa=dict(temp=args.dwa_temp),
126 | )
127 | )
128 | return weight_methods_parameters
129 |
--------------------------------------------------------------------------------
/experiments/toy/trainer.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | import logging
3 | from argparse import ArgumentParser
4 | from pathlib import Path
5 |
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | from tqdm import tqdm
10 |
11 | from experiments.toy.problem import Toy
12 | from experiments.toy.utils import plot_2d_pareto
13 | from experiments.utils import (
14 | common_parser,
15 | extract_weight_method_parameters_from_args,
16 | set_logger,
17 | )
18 | from methods.weight_methods import WeightMethods
19 |
20 | set_logger()
21 |
22 |
23 | def main(method_type, device, n_iter, scale):
24 | weight_methods_parameters = extract_weight_method_parameters_from_args(args)
25 | n_tasks = 2
26 |
27 | F = Toy(scale=scale)
28 |
29 | all_traj = dict()
30 |
31 | # the initial positions
32 | inits = [
33 | torch.Tensor([-8.5, 7.5]),
34 | torch.Tensor([0.0, 0.0]),
35 | torch.Tensor([9.0, 9.0]),
36 | torch.Tensor([-7.5, -0.5]),
37 | torch.Tensor([9, -1.0]),
38 | ]
39 |
40 | for i, init in enumerate(inits):
41 | traj = []
42 |         x = init.clone().to(device).requires_grad_(True)  # keep x a leaf tensor on the target device
45 |
46 | method = WeightMethods(
47 | method=method_type,
48 | device=device,
49 | n_tasks=n_tasks,
50 | **weight_methods_parameters[method_type],
51 | )
52 |
53 | optimizer = torch.optim.Adam(
54 | [
55 | dict(params=[x], lr=1e-3),
56 | dict(params=method.parameters(), lr=args.method_params_lr),
57 | ],
58 | )
59 |
60 | for _ in tqdm(range(n_iter)):
61 | traj.append(x.cpu().detach().numpy().copy())
62 |
63 | optimizer.zero_grad()
64 | f = F(x, False)
65 | _ = method.backward(
66 | losses=f,
67 | shared_parameters=(x,),
68 | task_specific_parameters=None,
69 | last_shared_parameters=None,
70 | representation=None,
71 | )
72 | optimizer.step()
73 |
74 | all_traj[i] = dict(init=init.cpu().detach().numpy().copy(), traj=np.array(traj))
75 |
76 | return all_traj
77 |
78 |
79 | if __name__ == "__main__":
80 | parser = ArgumentParser(
81 | "Toy example (modification of the one in CAGrad)", parents=[common_parser]
82 | )
83 | parser.set_defaults(n_epochs=35000, method="nashmtl", data_path=None)
84 | parser.add_argument(
85 | "--scale", default=1e-1, type=float, help="scale for first loss"
86 | )
87 | parser.add_argument("--out-path", default="outputs", type=Path, help="output path")
88 | parser.add_argument("--wandb_project", type=str, default=None, help="Name of Weights & Biases Project.")
89 | parser.add_argument("--wandb_entity", type=str, default=None, help="Name of Weights & Biases Entity.")
90 | args = parser.parse_args()
91 |
92 | if args.wandb_project is not None:
93 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args)
94 |
95 | out_path = args.out_path
96 | out_path.mkdir(parents=True, exist_ok=True)
97 | logging.info(f"Logs and plots are saved in: {out_path.as_posix()}")
98 |
99 | device = torch.device("cpu")
100 | all_traj = main(
101 | method_type=args.method, device=device, n_iter=args.n_epochs, scale=args.scale
102 | )
103 |
104 | # plot
105 | ax, fig, legend = plot_2d_pareto(trajectories=all_traj, scale=args.scale)
106 |
107 | title_map = {
108 | "nashmtl": "Nash-MTL",
109 | "cagrad": "CAGrad",
110 | "mgda": "MGDA",
111 | "pcgrad": "PCGrad",
112 | "ls": "LS",
113 | }
114 | ax.set_title(title_map[args.method], fontsize=25)
115 | plt.savefig(
116 | out_path / f"{args.method}.png",
117 | bbox_extra_artists=(legend,),
118 | bbox_inches="tight",
119 | facecolor="white",
120 | )
121 | plt.close()
122 |
123 | if wandb.run is not None:
124 | wandb.log({"Pareto Front": wandb.Image((out_path / f"{args.method}.png").as_posix())})
125 |
126 | wandb.finish()
127 |
--------------------------------------------------------------------------------
/experiments/nyuv2/data.py:
--------------------------------------------------------------------------------
1 | import fnmatch
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.utils.data.dataset import Dataset
9 |
10 | """Source: https://github.com/Cranial-XIX/CAGrad/blob/main/nyuv2/create_dataset.py
11 |
12 | """
13 |
14 |
15 | class RandomScaleCrop(object):
16 | """
17 | Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
18 | """
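19 |     # the transform picks a random scale sc, crops an (H / sc, W / sc) window,
20 |     # and resizes it back to (H, W); depth is divided by sc so the depth values
21 |     # remain consistent with the new apparent scale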
19 |
20 | def __init__(self, scale=[1.0, 1.2, 1.5]):
21 | self.scale = scale
22 |
23 | def __call__(self, img, label, depth, normal):
24 | height, width = img.shape[-2:]
25 | sc = self.scale[random.randint(0, len(self.scale) - 1)]
26 | h, w = int(height / sc), int(width / sc)
27 | i = random.randint(0, height - h)
28 | j = random.randint(0, width - w)
29 | img_ = F.interpolate(
30 | img[None, :, i : i + h, j : j + w],
31 | size=(height, width),
32 | mode="bilinear",
33 | align_corners=True,
34 | ).squeeze(0)
35 | label_ = (
36 | F.interpolate(
37 | label[None, None, i : i + h, j : j + w],
38 | size=(height, width),
39 | mode="nearest",
40 | )
41 | .squeeze(0)
42 | .squeeze(0)
43 | )
44 | depth_ = F.interpolate(
45 | depth[None, :, i : i + h, j : j + w], size=(height, width), mode="nearest"
46 | ).squeeze(0)
47 | normal_ = F.interpolate(
48 | normal[None, :, i : i + h, j : j + w],
49 | size=(height, width),
50 | mode="bilinear",
51 | align_corners=True,
52 | ).squeeze(0)
53 | return img_, label_, depth_ / sc, normal_
54 |
55 |
56 | class NYUv2(Dataset):
57 | """
58 | We could further improve the performance with the data augmentation of NYUv2 defined in:
59 | [1] PAD-Net: Multi-Tasks Guided Prediction-and-Distillation Network for Simultaneous Depth Estimation and Scene Parsing
60 | [2] Pattern affinitive propagation across depth, surface normal and semantic segmentation
61 | [3] Mti-net: Multiscale task interaction networks for multi-task learning
62 |     1. Random scale in a selected ratio 1.0, 1.2, and 1.5.
63 |     2. Random horizontal flip.
64 |     Please note that all baselines and MTAN did NOT apply data augmentation in the original paper.
65 | """
66 |
67 | def __init__(self, root, train=True, augmentation=False):
68 | self.train = train
69 | self.root = os.path.expanduser(root)
70 | self.augmentation = augmentation
71 |
72 | # read the data file
73 |         if train:
74 |             self.data_path = self.root + "/train"
75 |         else:
76 |             self.data_path = self.root + "/val"
77 |
78 | # calculate data length
79 | self.data_len = len(
80 | fnmatch.filter(os.listdir(self.data_path + "/image"), "*.npy")
81 | )
82 |
83 | def __getitem__(self, index):
84 | # load data from the pre-processed npy files
85 | image = torch.from_numpy(
86 | np.moveaxis(
87 | np.load(self.data_path + "/image/{:d}.npy".format(index)), -1, 0
88 | )
89 | )
90 | semantic = torch.from_numpy(
91 | np.load(self.data_path + "/label/{:d}.npy".format(index))
92 | )
93 | depth = torch.from_numpy(
94 | np.moveaxis(
95 | np.load(self.data_path + "/depth/{:d}.npy".format(index)), -1, 0
96 | )
97 | )
98 | normal = torch.from_numpy(
99 | np.moveaxis(
100 | np.load(self.data_path + "/normal/{:d}.npy".format(index)), -1, 0
101 | )
102 | )
103 |
104 | # apply data augmentation if required
105 | if self.augmentation:
106 | image, semantic, depth, normal = RandomScaleCrop()(
107 | image, semantic, depth, normal
108 | )
109 | if torch.rand(1) < 0.5:
110 | image = torch.flip(image, dims=[2])
111 | semantic = torch.flip(semantic, dims=[1])
112 | depth = torch.flip(depth, dims=[2])
113 | normal = torch.flip(normal, dims=[2])
114 | normal[0, :, :] = -normal[0, :, :]
115 |
116 | return image.float(), semantic.float(), depth.float(), normal.float()
117 |
118 | def __len__(self):
119 | return self.data_len
120 |
--------------------------------------------------------------------------------
/experiments/quantum_chemistry/trainer.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | import torch_geometric.transforms as T
7 | from torch_geometric.datasets import QM9
8 | from torch_geometric.loader import DataLoader
9 | from tqdm import trange
10 | import wandb
11 |
12 | from experiments.quantum_chemistry.models import Net
13 | from experiments.quantum_chemistry.utils import (
14 | Complete,
15 | MyTransform,
16 | delta_fn,
17 | multiply_indx,
18 | )
19 | from experiments.quantum_chemistry.utils import target_idx as targets
20 | from experiments.utils import (
21 | common_parser,
22 | extract_weight_method_parameters_from_args,
23 | get_device,
24 | set_logger,
25 | set_seed,
26 | str2bool,
27 | )
28 | from methods.weight_methods import WeightMethods
29 |
30 | set_logger()
31 |
32 |
33 | @torch.no_grad()
34 | def evaluate(model, loader, std, scale_target):
35 | model.eval()
36 | data_size = 0.0
37 | task_losses = 0.0
38 | for i, data in enumerate(loader):
39 | data = data.to(device)
40 | out = model(data)
41 | if scale_target:
42 | task_losses += F.l1_loss(
43 | out * std.to(device), data.y * std.to(device), reduction="none"
44 | ).sum(
45 | 0
46 | ) # MAE
47 | else:
48 | task_losses += F.l1_loss(out, data.y, reduction="none").sum(0) # MAE
49 | data_size += len(data.y)
50 |
51 | model.train()
52 |
53 | avg_task_losses = task_losses / data_size
54 |
55 | # Report meV instead of eV.
56 | avg_task_losses = avg_task_losses.detach().cpu().numpy()
57 | avg_task_losses[multiply_indx] *= 1000
58 |
59 | delta_m = delta_fn(avg_task_losses)
60 | return dict(
61 | avg_loss=avg_task_losses.mean(),
62 | avg_task_losses=avg_task_losses,
63 | delta_m=delta_m,
64 | )
65 |
66 |
67 | def main(
68 | data_path: str,
69 | batch_size: int,
70 | device: torch.device,
71 | method: str,
72 | weight_method_params: dict,
73 | lr: float,
74 | method_params_lr: float,
75 | n_epochs: int,
76 | targets: list = None,
77 | scale_target: bool = True,
78 | main_task: int = None,
79 | ):
80 | dim = 64
81 | model = Net(n_tasks=len(targets), num_features=11, dim=dim).to(device)
82 |
83 | transform = T.Compose([MyTransform(targets), Complete(), T.Distance(norm=False)])
84 | dataset = QM9(data_path, transform=transform).shuffle()
85 |
86 | # Split datasets.
87 | test_dataset = dataset[:10000]
88 | val_dataset = dataset[10000:20000]
89 | train_dataset = dataset[20000:]
90 |
91 | std = None
92 | if scale_target:
93 | mean = train_dataset.data.y[:, targets].mean(dim=0, keepdim=True)
94 | std = train_dataset.data.y[:, targets].std(dim=0, keepdim=True)
95 |
96 | dataset.data.y[:, targets] = (dataset.data.y[:, targets] - mean) / std
97 |
98 | test_loader = DataLoader(
99 | test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
100 | )
101 | val_loader = DataLoader(
102 | val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
103 | )
104 | train_loader = DataLoader(
105 | train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
106 | )
107 |
108 | weight_method = WeightMethods(
109 | method,
110 | n_tasks=len(targets),
111 | device=device,
112 | **weight_method_params[method],
113 | )
114 |
115 | optimizer = torch.optim.Adam(
116 | [
117 | dict(params=model.parameters(), lr=lr),
118 | dict(params=weight_method.parameters(), lr=method_params_lr),
119 | ],
120 | )
121 |
122 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
123 | optimizer, mode="min", factor=0.7, patience=5, min_lr=0.00001
124 | )
125 |
126 | epoch_iterator = trange(n_epochs)
127 |
128 | best_val = np.inf
129 | best_test = np.inf
130 | best_test_delta = np.inf
131 | best_val_delta = np.inf
132 | best_test_results = None
133 |
134 | for epoch in epoch_iterator:
135 | lr = scheduler.optimizer.param_groups[0]["lr"]
136 | for j, data in enumerate(train_loader):
137 | model.train()
138 |
139 | data = data.to(device)
140 | optimizer.zero_grad()
141 |
142 | out, features = model(data, return_representation=True)
143 |
144 | losses = F.mse_loss(out, data.y, reduction="none").mean(0)
145 |
146 | loss, extra_outputs = weight_method.backward(
147 | losses=losses,
148 | shared_parameters=list(model.shared_parameters()),
149 | task_specific_parameters=list(model.task_specific_parameters()),
150 | last_shared_parameters=list(model.last_shared_parameters()),
151 | representation=features,
152 | )
153 |
154 | optimizer.step()
155 | 
158 | val_loss_dict = evaluate(model, val_loader, std=std, scale_target=scale_target)
159 | test_loss_dict = evaluate(
160 | model, test_loader, std=std, scale_target=scale_target
161 | )
162 | val_loss = val_loss_dict["avg_loss"]
163 | val_delta = val_loss_dict["delta_m"]
164 | test_loss = test_loss_dict["avg_loss"]
165 | test_delta = test_loss_dict["delta_m"]
166 |
167 | if method == "stl":
168 | best_val_criteria = val_loss_dict["avg_task_losses"][main_task] <= best_val
169 | else:
170 | best_val_criteria = val_delta <= best_val_delta
171 |
172 | if best_val_criteria:
173 | best_val = val_loss
174 | best_test = test_loss
175 | best_test_results = test_loss_dict
176 | best_val_delta = val_delta
177 | best_test_delta = test_delta
178 |
179 | # for logger
180 | epoch_iterator.set_description(
181 | f"epoch {epoch} | lr={lr} | train loss {losses.mean().item():.3f} | val loss: {val_loss:.3f} | "
182 | f"test loss: {test_loss:.3f} | best test loss {best_test:.3f} | best_test_delta {best_test_delta:.3f}"
183 | )
184 |
185 | if wandb.run is not None:
186 | wandb.log({"Learning Rate": lr}, step=epoch)
187 | wandb.log({"Train Loss": losses.mean().item()}, step=epoch)
188 | wandb.log({"Val Loss": val_loss}, step=epoch)
189 | wandb.log({"Val Delta": val_delta}, step=epoch)
190 | wandb.log({"Test Loss": test_loss}, step=epoch)
191 | wandb.log({"Test Delta": test_delta}, step=epoch)
192 | wandb.log({"Best Test Loss": best_test}, step=epoch)
193 | wandb.log({"Best Test Delta": best_test_delta}, step=epoch)
194 |
195 | scheduler.step(
196 | val_loss_dict["avg_task_losses"][main_task]
197 | if method == "stl"
198 | else val_delta
199 | )
200 |
201 |
202 | if __name__ == "__main__":
203 | parser = ArgumentParser("QM9", parents=[common_parser])
204 | parser.set_defaults(
205 | data_path="dataset",
206 | lr=1e-3,
207 | n_epochs=300,
208 | batch_size=120,
209 | method="nashmtl",
210 | )
211 | parser.add_argument("--scale-y", default=True, type=str2bool)
212 | parser.add_argument("--wandb_project", type=str, default=None, help="Name of Weights & Biases Project.")
213 | parser.add_argument("--wandb_entity", type=str, default=None, help="Name of Weights & Biases Entity.")
214 | args = parser.parse_args()
215 |
216 | # set seed
217 | set_seed(args.seed)
218 |
219 | if args.wandb_project is not None:
220 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args)
221 |
222 | weight_method_params = extract_weight_method_parameters_from_args(args)
223 |
224 | device = get_device(gpus=args.gpu)
225 | main(
226 | data_path=args.data_path,
227 | batch_size=args.batch_size,
228 | device=device,
229 | method=args.method,
230 | weight_method_params=weight_method_params,
231 | lr=args.lr,
232 | method_params_lr=args.method_params_lr,
233 | n_epochs=args.n_epochs,
234 | targets=targets,
235 | scale_target=args.scale_y,
236 | main_task=args.main_task,
237 | )
238 |
239 | if wandb.run is not None:
240 | wandb.finish()
241 |
--------------------------------------------------------------------------------
/methods/min_norm_solvers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | # This code is from
6 | # Multi-Task Learning as Multi-Objective Optimization
7 | # Ozan Sener, Vladlen Koltun
8 | # Neural Information Processing Systems (NeurIPS) 2018
9 | # https://github.com/intel-isl/MultiObjectiveOptimization
10 | class MinNormSolver:
11 | MAX_ITER = 250
12 | STOP_CRIT = 1e-5
13 |
14 | @staticmethod
15 | def _min_norm_element_from2(v1v1, v1v2, v2v2):
16 | """
17 | Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2
18 |         d is the distance (objective) optimized
19 |         v1v1 = <x1, x1>
20 |         v1v2 = <x1, x2>
21 |         v2v2 = <x2, x2>
22 | """
23 | if v1v2 >= v1v1:
24 | # Case: Fig 1, third column
25 | gamma = 0.999
26 | cost = v1v1
27 | return gamma, cost
28 | if v1v2 >= v2v2:
29 | # Case: Fig 1, first column
30 | gamma = 0.001
31 | cost = v2v2
32 | return gamma, cost
33 | # Case: Fig 1, second column
34 | gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
35 | cost = v2v2 + gamma * (v1v2 - v2v2)
36 | return gamma, cost
37 |
38 | @staticmethod
39 | def _min_norm_2d(vecs, dps):
40 | """
41 | Find the minimum norm solution as combination of two points
42 | This is correct only in 2D
43 |     ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_i >= 0 for all i, c_i + c_j = 1.0 for some i, j
44 | """
45 | dmin = 1e8
46 | for i in range(len(vecs)):
47 | for j in range(i + 1, len(vecs)):
48 | if (i, j) not in dps:
49 | dps[(i, j)] = 0.0
50 | for k in range(len(vecs[i])):
51 | dps[(i, j)] += torch.dot(
52 | vecs[i][k], vecs[j][k]
53 | ).item() # torch.dot(vecs[i][k], vecs[j][k]).dataset[0]
54 | dps[(j, i)] = dps[(i, j)]
55 | if (i, i) not in dps:
56 | dps[(i, i)] = 0.0
57 | for k in range(len(vecs[i])):
58 | dps[(i, i)] += torch.dot(
59 | vecs[i][k], vecs[i][k]
60 | ).item() # torch.dot(vecs[i][k], vecs[i][k]).dataset[0]
61 | if (j, j) not in dps:
62 | dps[(j, j)] = 0.0
63 | for k in range(len(vecs[i])):
64 | dps[(j, j)] += torch.dot(
65 | vecs[j][k], vecs[j][k]
66 | ).item() # torch.dot(vecs[j][k], vecs[j][k]).dataset[0]
67 | c, d = MinNormSolver._min_norm_element_from2(
68 | dps[(i, i)], dps[(i, j)], dps[(j, j)]
69 | )
70 | if d < dmin:
71 | dmin = d
72 | sol = [(i, j), c, d]
73 | return sol, dps
74 |
75 | @staticmethod
76 | def _projection2simplex(y):
77 | """
78 | Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
79 | """
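80 |         # standard O(m log m) Euclidean projection onto the probability simplex:
81 |         # sort y in descending order, find the threshold tmax_f, and clip
82 |         # y - tmax_f at zero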
80 | m = len(y)
81 | sorted_y = np.flip(np.sort(y), axis=0)
82 | tmpsum = 0.0
83 | tmax_f = (np.sum(y) - 1.0) / m
84 | for i in range(m - 1):
85 | tmpsum += sorted_y[i]
86 | tmax = (tmpsum - 1) / (i + 1.0)
87 | if tmax > sorted_y[i + 1]:
88 | tmax_f = tmax
89 | break
90 | return np.maximum(y - tmax_f, np.zeros(y.shape))
91 |
92 | @staticmethod
93 | def _next_point(cur_val, grad, n):
94 | proj_grad = grad - (np.sum(grad) / n)
95 | tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
96 | tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])
97 |
98 | skippers = np.sum(tm1 < 1e-7) + np.sum(tm2 < 1e-7)
99 | t = 1
100 | if len(tm1[tm1 > 1e-7]) > 0:
101 | t = np.min(tm1[tm1 > 1e-7])
102 | if len(tm2[tm2 > 1e-7]) > 0:
103 | t = min(t, np.min(tm2[tm2 > 1e-7]))
104 |
105 | next_point = proj_grad * t + cur_val
106 | next_point = MinNormSolver._projection2simplex(next_point)
107 | return next_point
108 |
109 | @staticmethod
110 | def find_min_norm_element(vecs):
111 | """
112 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
113 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
114 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
115 | Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence
116 | """
117 | # Solution lying at the combination of two points
118 | dps = {}
119 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
120 |
121 | n = len(vecs)
122 | sol_vec = np.zeros(n)
123 | sol_vec[init_sol[0][0]] = init_sol[1]
124 | sol_vec[init_sol[0][1]] = 1 - init_sol[1]
125 |
126 | if n < 3:
127 | # This is optimal for n=2, so return the solution
128 | return sol_vec, init_sol[2]
129 |
130 | iter_count = 0
131 |
132 | grad_mat = np.zeros((n, n))
133 | for i in range(n):
134 | for j in range(n):
135 | grad_mat[i, j] = dps[(i, j)]
136 |
137 |         while iter_count < MinNormSolver.MAX_ITER:
138 |             iter_count += 1
139 |             grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
140 |             new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
141 |             # Re-compute the inner products for line search
142 |             v1v1 = 0.0
143 |             v1v2 = 0.0
144 |             v2v2 = 0.0
145 |             for i in range(n):
146 |                 for j in range(n):
147 |                     v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
148 |                     v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
149 |                     v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
150 |             nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
151 |             new_sol_vec = nc * sol_vec + (1 - nc) * new_point
152 |             change = new_sol_vec - sol_vec
153 |             if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
154 |                 return sol_vec, nd
155 |             sol_vec = new_sol_vec
156 |         # if the loop hits MAX_ITER without converging, return the last iterate
157 |         return sol_vec, nd
155 |
156 | @staticmethod
157 | def find_min_norm_element_FW(vecs):
158 | """
159 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
160 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
161 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
162 | Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence
163 | """
164 | # Solution lying at the combination of two points
165 | dps = {}
166 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
167 |
168 | n = len(vecs)
169 | sol_vec = np.zeros(n)
170 | sol_vec[init_sol[0][0]] = init_sol[1]
171 | sol_vec[init_sol[0][1]] = 1 - init_sol[1]
172 |
173 | if n < 3:
174 | # This is optimal for n=2, so return the solution
175 | return sol_vec, init_sol[2]
176 |
177 | iter_count = 0
178 |
179 | grad_mat = np.zeros((n, n))
180 | for i in range(n):
181 | for j in range(n):
182 | grad_mat[i, j] = dps[(i, j)]
183 |
184 |         while iter_count < MinNormSolver.MAX_ITER:
185 |             iter_count += 1
186 |             t_iter = np.argmin(np.dot(grad_mat, sol_vec))
187 | 
188 |             v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
189 |             v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
190 |             v2v2 = grad_mat[t_iter, t_iter]
191 | 
192 |             nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
193 |             new_sol_vec = nc * sol_vec
194 |             new_sol_vec[t_iter] += 1 - nc
195 | 
196 |             change = new_sol_vec - sol_vec
197 |             if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
198 |                 return sol_vec, nd
199 |             sol_vec = new_sol_vec
200 |         # if Frank-Wolfe hits MAX_ITER without converging, return the last iterate
201 |         return sol_vec, nd
199 |
200 |
201 | def gradient_normalizers(grads, losses, normalization_type):
202 | gn = {}
203 | if normalization_type == "norm":
204 | for t in grads:
205 |             gn[t] = np.sqrt(np.sum([gr.pow(2).sum().item() for gr in grads[t]]))
206 | elif normalization_type == "loss":
207 | for t in grads:
208 | gn[t] = losses[t]
209 | elif normalization_type == "loss+":
210 | for t in grads:
211 |             gn[t] = losses[t] * np.sqrt(
212 |                 np.sum([gr.pow(2).sum().item() for gr in grads[t]])
213 |             )
214 | elif normalization_type == "none":
215 | for t in grads:
216 | gn[t] = 1.0
217 |     else:
218 |         raise ValueError(f"Invalid normalization type: {normalization_type!r}")
219 | return gn
220 |
--------------------------------------------------------------------------------
/experiments/nyuv2/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import wandb
3 | from argparse import ArgumentParser
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.utils.data import DataLoader
9 | from tqdm import trange
10 |
11 | from experiments.nyuv2.data import NYUv2
12 | from experiments.nyuv2.models import SegNet, SegNetMtan
13 | from experiments.nyuv2.utils import ConfMatrix, delta_fn, depth_error, normal_error
14 | from experiments.utils import (
15 | common_parser,
16 | extract_weight_method_parameters_from_args,
17 | get_device,
18 | set_logger,
19 | set_seed,
20 | str2bool,
21 | )
22 | from methods.weight_methods import WeightMethods
23 |
24 | set_logger()
25 |
26 |
27 | def calc_loss(x_pred, x_output, task_type):
28 | device = x_pred.device
29 |
30 | # binary mark to mask out undefined pixel space
31 | binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(device)
32 |
33 | if task_type == "semantic":
34 | # semantic loss: depth-wise cross entropy
35 | loss = F.nll_loss(x_pred, x_output, ignore_index=-1)
36 |
37 | if task_type == "depth":
38 | # depth loss: l1 norm
39 | loss = torch.sum(torch.abs(x_pred - x_output) * binary_mask) / torch.nonzero(
40 | binary_mask, as_tuple=False
41 | ).size(0)
42 |
43 | if task_type == "normal":
44 | # normal loss: dot product
45 | loss = 1 - torch.sum((x_pred * x_output) * binary_mask) / torch.nonzero(
46 | binary_mask, as_tuple=False
47 | ).size(0)
48 |
49 | return loss
50 |
51 |
52 | def main(path, lr, bs, device):
53 | # ----
54 | # Nets
55 | # ---
56 | model = dict(segnet=SegNet(), mtan=SegNetMtan())[args.model]
57 | model = model.to(device)
58 |
59 | # dataset and dataloaders
60 | log_str = (
61 | "Applying data augmentation on NYUv2."
62 | if args.apply_augmentation
63 | else "Standard training strategy without data augmentation."
64 | )
65 | logging.info(log_str)
66 |
67 | nyuv2_train_set = NYUv2(
68 | root=path.as_posix(), train=True, augmentation=args.apply_augmentation
69 | )
70 | nyuv2_test_set = NYUv2(root=path.as_posix(), train=False)
71 |
72 | train_loader = torch.utils.data.DataLoader(
73 | dataset=nyuv2_train_set, batch_size=bs, shuffle=True
74 | )
75 |
76 | test_loader = torch.utils.data.DataLoader(
77 | dataset=nyuv2_test_set, batch_size=bs, shuffle=False
78 | )
79 |
80 | # weight method
81 | weight_methods_parameters = extract_weight_method_parameters_from_args(args)
82 | weight_method = WeightMethods(
83 | args.method, n_tasks=3, device=device, **weight_methods_parameters[args.method]
84 | )
85 |
86 | # optimizer
87 | optimizer = torch.optim.Adam(
88 | [
89 | dict(params=model.parameters(), lr=lr),
90 | dict(params=weight_method.parameters(), lr=args.method_params_lr),
91 | ],
92 | )
93 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
94 |
95 | epochs = args.n_epochs
96 | epoch_iter = trange(epochs)
97 | train_batch = len(train_loader)
98 | test_batch = len(test_loader)
99 | avg_cost = np.zeros([epochs, 24], dtype=np.float32)
100 | custom_step = -1
101 |     for epoch in epoch_iter:
102 |         cost = np.zeros(24, dtype=np.float32)
103 |         # re-initialize each epoch: conf_mat is reassigned for the test pass below,
104 |         # so reusing one instance across epochs would mix train and test counts
105 |         conf_mat = ConfMatrix(model.segnet.class_nb)
104 |
105 | for j, batch in enumerate(train_loader):
106 | custom_step += 1
107 |
108 | model.train()
109 | optimizer.zero_grad()
110 |
111 | train_data, train_label, train_depth, train_normal = batch
112 | train_data, train_label = train_data.to(device), train_label.long().to(
113 | device
114 | )
115 | train_depth, train_normal = train_depth.to(device), train_normal.to(device)
116 |
117 | train_pred, features = model(train_data, return_representation=True)
118 |
119 | losses = torch.stack(
120 | (
121 | calc_loss(train_pred[0], train_label, "semantic"),
122 | calc_loss(train_pred[1], train_depth, "depth"),
123 | calc_loss(train_pred[2], train_normal, "normal"),
124 | )
125 | )
126 |
127 | loss, extra_outputs = weight_method.backward(
128 | losses=losses,
129 | shared_parameters=list(model.shared_parameters()),
130 | task_specific_parameters=list(model.task_specific_parameters()),
131 | last_shared_parameters=list(model.last_shared_parameters()),
132 | representation=features,
133 | )
134 |
135 | optimizer.step()
136 |
137 | # accumulate label prediction for every pixel in training images
138 | conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten())
139 |
140 | cost[0] = losses[0].item()
141 | cost[3] = losses[1].item()
142 | cost[4], cost[5] = depth_error(train_pred[1], train_depth)
143 | cost[6] = losses[2].item()
144 | cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(
145 | train_pred[2], train_normal
146 | )
147 | avg_cost[epoch, :12] += cost[:12] / train_batch
148 |
149 | epoch_iter.set_description(
150 | f"[{epoch+1} {j+1}/{train_batch}] semantic loss: {losses[0].item():.3f}, "
151 | f"depth loss: {losses[1].item():.3f}, "
152 | f"normal loss: {losses[2].item():.3f}"
153 | )
154 |
155 | # scheduler
156 | scheduler.step()
157 | # compute mIoU and acc
158 | avg_cost[epoch, 1:3] = conf_mat.get_metrics()
159 |
160 | # todo: move evaluate to function?
161 | # evaluating test data
162 | model.eval()
163 | conf_mat = ConfMatrix(model.segnet.class_nb)
164 | with torch.no_grad(): # operations inside don't track history
165 | test_dataset = iter(test_loader)
166 | for k in range(test_batch):
167 |                 test_data, test_label, test_depth, test_normal = next(test_dataset)  # DataLoader iterators no longer expose .next()
168 | test_data, test_label = test_data.to(device), test_label.long().to(
169 | device
170 | )
171 | test_depth, test_normal = test_depth.to(device), test_normal.to(device)
172 |
173 | test_pred = model(test_data)
174 | test_loss = torch.stack(
175 | (
176 | calc_loss(test_pred[0], test_label, "semantic"),
177 | calc_loss(test_pred[1], test_depth, "depth"),
178 | calc_loss(test_pred[2], test_normal, "normal"),
179 | )
180 | )
181 |
182 | conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten())
183 |
184 | cost[12] = test_loss[0].item()
185 | cost[15] = test_loss[1].item()
186 | cost[16], cost[17] = depth_error(test_pred[1], test_depth)
187 | cost[18] = test_loss[2].item()
188 | cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(
189 | test_pred[2], test_normal
190 | )
191 | avg_cost[epoch, 12:] += cost[12:] / test_batch
192 |
193 | # compute mIoU and acc
194 | avg_cost[epoch, 13:15] = conf_mat.get_metrics()
195 |
196 | # Test Delta_m
197 | test_delta_m = delta_fn(
198 | avg_cost[epoch, [13, 14, 16, 17, 19, 20, 21, 22, 23]]
199 | )
200 |
201 | # print results
202 | print(
203 | f"LOSS FORMAT: SEMANTIC_LOSS MEAN_IOU PIX_ACC | DEPTH_LOSS ABS_ERR REL_ERR "
204 | f"| NORMAL_LOSS MEAN MED <11.25 <22.5 <30 | ∆m (test)"
205 | )
206 | print(
207 | f"Epoch: {epoch:04d} | TRAIN: {avg_cost[epoch, 0]:.4f} {avg_cost[epoch, 1]:.4f} {avg_cost[epoch, 2]:.4f} "
208 | f"| {avg_cost[epoch, 3]:.4f} {avg_cost[epoch, 4]:.4f} {avg_cost[epoch, 5]:.4f} | {avg_cost[epoch, 6]:.4f} "
209 | f"{avg_cost[epoch, 7]:.4f} {avg_cost[epoch, 8]:.4f} {avg_cost[epoch, 9]:.4f} {avg_cost[epoch, 10]:.4f} {avg_cost[epoch, 11]:.4f} || "
210 | f"TEST: {avg_cost[epoch, 12]:.4f} {avg_cost[epoch, 13]:.4f} {avg_cost[epoch, 14]:.4f} | "
211 | f"{avg_cost[epoch, 15]:.4f} {avg_cost[epoch, 16]:.4f} {avg_cost[epoch, 17]:.4f} | {avg_cost[epoch, 18]:.4f} "
212 | f"{avg_cost[epoch, 19]:.4f} {avg_cost[epoch, 20]:.4f} {avg_cost[epoch, 21]:.4f} {avg_cost[epoch, 22]:.4f} {avg_cost[epoch, 23]:.4f} "
213 | f"| {test_delta_m:.3f}"
214 | )
215 |
216 | if wandb.run is not None:
217 | wandb.log({"Train Semantic Loss": avg_cost[epoch, 0]}, step=epoch)
218 | wandb.log({"Train Mean IoU": avg_cost[epoch, 1]}, step=epoch)
219 | wandb.log({"Train Pixel Accuracy": avg_cost[epoch, 2]}, step=epoch)
220 | wandb.log({"Train Depth Loss": avg_cost[epoch, 3]}, step=epoch)
221 | wandb.log({"Train Absolute Error": avg_cost[epoch, 4]}, step=epoch)
222 | wandb.log({"Train Relative Error": avg_cost[epoch, 5]}, step=epoch)
223 | wandb.log({"Train Normal Loss": avg_cost[epoch, 6]}, step=epoch)
224 | wandb.log({"Train Loss Mean": avg_cost[epoch, 7]}, step=epoch)
225 | wandb.log({"Train Loss Med": avg_cost[epoch, 8]}, step=epoch)
226 | wandb.log({"Train Loss <11.25": avg_cost[epoch, 9]}, step=epoch)
227 | wandb.log({"Train Loss <22.5": avg_cost[epoch, 10]}, step=epoch)
228 | wandb.log({"Train Loss <30": avg_cost[epoch, 11]}, step=epoch)
229 |
230 | wandb.log({"Test Semantic Loss": avg_cost[epoch, 12]}, step=epoch)
231 | wandb.log({"Test Mean IoU": avg_cost[epoch, 13]}, step=epoch)
232 | wandb.log({"Test Pixel Accuracy": avg_cost[epoch, 14]}, step=epoch)
233 | wandb.log({"Test Depth Loss": avg_cost[epoch, 15]}, step=epoch)
234 | wandb.log({"Test Absolute Error": avg_cost[epoch, 16]}, step=epoch)
235 | wandb.log({"Test Relative Error": avg_cost[epoch, 17]}, step=epoch)
236 | wandb.log({"Test Normal Loss": avg_cost[epoch, 18]}, step=epoch)
237 | wandb.log({"Test Loss Mean": avg_cost[epoch, 19]}, step=epoch)
238 | wandb.log({"Test Loss Med": avg_cost[epoch, 20]}, step=epoch)
239 | wandb.log({"Test Loss <11.25": avg_cost[epoch, 21]}, step=epoch)
240 | wandb.log({"Test Loss <22.5": avg_cost[epoch, 22]}, step=epoch)
241 | wandb.log({"Test Loss <30": avg_cost[epoch, 23]}, step=epoch)
242 | wandb.log({"Test ∆m": test_delta_m}, step=epoch)
243 |
244 |
245 | if __name__ == "__main__":
246 | parser = ArgumentParser("NYUv2", parents=[common_parser])
247 | parser.set_defaults(
248 | data_path="dataset",
249 | lr=1e-4,
250 | n_epochs=200,
251 | batch_size=2,
252 | )
253 | parser.add_argument(
254 | "--model",
255 | type=str,
256 | default="mtan",
257 | choices=["segnet", "mtan"],
258 | help="model type",
259 | )
260 | parser.add_argument(
261 | "--apply-augmentation", type=str2bool, default=True, help="data augmentations"
262 | )
263 | parser.add_argument("--wandb_project", type=str, default=None, help="Name of Weights & Biases Project.")
264 | parser.add_argument("--wandb_entity", type=str, default=None, help="Name of Weights & Biases Entity.")
265 | args = parser.parse_args()
266 |
267 | # set seed
268 | set_seed(args.seed)
269 |
270 | if args.wandb_project is not None:
271 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args)
272 |
273 | device = get_device(gpus=args.gpu)
274 | main(path=args.data_path, lr=args.lr, bs=args.batch_size, device=device)
275 |
276 | if wandb.run is not None:
277 | wandb.finish()
278 |
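# Illustrative usage note (flag spellings are assumptions -- the exact names
# come from `common_parser`, presumably in experiments/utils.py): a typical
# run would look like
#
#     python trainer.py --method=nashmtl --model=mtan
#
# with the NYUv2 data extracted under ./dataset (the `data_path` default above).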
--------------------------------------------------------------------------------
/experiments/nyuv2/models.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class _SegNet(nn.Module):
9 | """SegNet MTAN"""
10 |
11 | def __init__(self):
12 | super(_SegNet, self).__init__()
13 | # initialise network parameters
14 | filter = [64, 128, 256, 512, 512]
15 | self.class_nb = 13
16 |
17 | # define encoder decoder layers
18 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
19 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
20 | for i in range(4):
21 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
22 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))
23 |
24 | # define convolution layer
25 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
26 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
27 | for i in range(4):
28 | if i == 0:
29 | self.conv_block_enc.append(
30 | self.conv_layer([filter[i + 1], filter[i + 1]])
31 | )
32 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
33 | else:
34 | self.conv_block_enc.append(
35 | nn.Sequential(
36 | self.conv_layer([filter[i + 1], filter[i + 1]]),
37 | self.conv_layer([filter[i + 1], filter[i + 1]]),
38 | )
39 | )
40 | self.conv_block_dec.append(
41 | nn.Sequential(
42 | self.conv_layer([filter[i], filter[i]]),
43 | self.conv_layer([filter[i], filter[i]]),
44 | )
45 | )
46 |
47 | # define task attention layers
48 | self.encoder_att = nn.ModuleList(
49 | [nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])]
50 | )
51 | self.decoder_att = nn.ModuleList(
52 | [nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])])]
53 | )
54 | self.encoder_block_att = nn.ModuleList(
55 | [self.conv_layer([filter[0], filter[1]])]
56 | )
57 | self.decoder_block_att = nn.ModuleList(
58 | [self.conv_layer([filter[0], filter[0]])]
59 | )
60 |
61 | for j in range(3):
62 | if j < 2:
63 | self.encoder_att.append(
64 | nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])
65 | )
66 | self.decoder_att.append(
67 | nn.ModuleList(
68 | [self.att_layer([2 * filter[0], filter[0], filter[0]])]
69 | )
70 | )
71 | for i in range(4):
72 | self.encoder_att[j].append(
73 | self.att_layer([2 * filter[i + 1], filter[i + 1], filter[i + 1]])
74 | )
75 | self.decoder_att[j].append(
76 | self.att_layer([filter[i + 1] + filter[i], filter[i], filter[i]])
77 | )
78 |
79 | for i in range(4):
80 | if i < 3:
81 | self.encoder_block_att.append(
82 | self.conv_layer([filter[i + 1], filter[i + 2]])
83 | )
84 | self.decoder_block_att.append(
85 | self.conv_layer([filter[i + 1], filter[i]])
86 | )
87 | else:
88 | self.encoder_block_att.append(
89 | self.conv_layer([filter[i + 1], filter[i + 1]])
90 | )
91 | self.decoder_block_att.append(
92 | self.conv_layer([filter[i + 1], filter[i + 1]])
93 | )
94 |
95 | self.pred_task1 = self.conv_layer([filter[0], self.class_nb], pred=True)
96 | self.pred_task2 = self.conv_layer([filter[0], 1], pred=True)
97 | self.pred_task3 = self.conv_layer([filter[0], 3], pred=True)
98 |
99 | # define pooling and unpooling functions
100 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
101 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | nn.init.xavier_normal_(m.weight)
106 | nn.init.constant_(m.bias, 0)
107 | elif isinstance(m, nn.BatchNorm2d):
108 | nn.init.constant_(m.weight, 1)
109 | nn.init.constant_(m.bias, 0)
110 | elif isinstance(m, nn.Linear):
111 | nn.init.xavier_normal_(m.weight)
112 | nn.init.constant_(m.bias, 0)
113 |
114 | def shared_modules(self):
115 | return [
116 | self.encoder_block,
117 | self.decoder_block,
118 | self.conv_block_enc,
119 | self.conv_block_dec,
120 | # self.encoder_att, self.decoder_att,
121 | self.encoder_block_att,
122 | self.decoder_block_att,
123 | self.down_sampling,
124 | self.up_sampling,
125 | ]
126 |
127 | def zero_grad_shared_modules(self):
128 | for mm in self.shared_modules():
129 | mm.zero_grad()
130 |
131 | def conv_layer(self, channel, pred=False):
132 | if not pred:
133 | conv_block = nn.Sequential(
134 | nn.Conv2d(
135 | in_channels=channel[0],
136 | out_channels=channel[1],
137 | kernel_size=3,
138 | padding=1,
139 | ),
140 | nn.BatchNorm2d(num_features=channel[1]),
141 | nn.ReLU(inplace=True),
142 | )
143 | else:
144 | conv_block = nn.Sequential(
145 | nn.Conv2d(
146 | in_channels=channel[0],
147 | out_channels=channel[0],
148 | kernel_size=3,
149 | padding=1,
150 | ),
151 | nn.Conv2d(
152 | in_channels=channel[0],
153 | out_channels=channel[1],
154 | kernel_size=1,
155 | padding=0,
156 | ),
157 | )
158 | return conv_block
159 |
160 | def att_layer(self, channel):
161 | att_block = nn.Sequential(
162 | nn.Conv2d(
163 | in_channels=channel[0],
164 | out_channels=channel[1],
165 | kernel_size=1,
166 | padding=0,
167 | ),
168 | nn.BatchNorm2d(channel[1]),
169 | nn.ReLU(inplace=True),
170 | nn.Conv2d(
171 | in_channels=channel[1],
172 | out_channels=channel[2],
173 | kernel_size=1,
174 | padding=0,
175 | ),
176 | nn.BatchNorm2d(channel[2]),
177 | nn.Sigmoid(),
178 | )
179 | return att_block
180 |
181 | def forward(self, x):
182 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = (
183 | [0] * 5 for _ in range(5)
184 | )
185 | for i in range(5):
186 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))
187 |
188 | # define attention list for tasks
189 | atten_encoder, atten_decoder = ([0] * 3 for _ in range(2))
190 | for i in range(3):
191 | atten_encoder[i], atten_decoder[i] = ([0] * 5 for _ in range(2))
192 | for i in range(3):
193 | for j in range(5):
194 | atten_encoder[i][j], atten_decoder[i][j] = ([0] * 3 for _ in range(2))
195 |
196 | # define global shared network
197 | for i in range(5):
198 | if i == 0:
199 | g_encoder[i][0] = self.encoder_block[i](x)
200 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
201 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
202 | else:
203 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
204 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
205 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
206 |
207 | for i in range(5):
208 | if i == 0:
209 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
210 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
211 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
212 | else:
213 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
214 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
215 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
216 |
217 | # define task dependent attention module
218 | for i in range(3):
219 | for j in range(5):
220 | if j == 0:
221 | atten_encoder[i][j][0] = self.encoder_att[i][j](g_encoder[j][0])
222 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
223 | atten_encoder[i][j][2] = self.encoder_block_att[j](
224 | atten_encoder[i][j][1]
225 | )
226 | atten_encoder[i][j][2] = F.max_pool2d(
227 | atten_encoder[i][j][2], kernel_size=2, stride=2
228 | )
229 | else:
230 | atten_encoder[i][j][0] = self.encoder_att[i][j](
231 | torch.cat((g_encoder[j][0], atten_encoder[i][j - 1][2]), dim=1)
232 | )
233 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
234 | atten_encoder[i][j][2] = self.encoder_block_att[j](
235 | atten_encoder[i][j][1]
236 | )
237 | atten_encoder[i][j][2] = F.max_pool2d(
238 | atten_encoder[i][j][2], kernel_size=2, stride=2
239 | )
240 |
241 | for j in range(5):
242 | if j == 0:
243 | atten_decoder[i][j][0] = F.interpolate(
244 | atten_encoder[i][-1][-1],
245 | scale_factor=2,
246 | mode="bilinear",
247 | align_corners=True,
248 | )
249 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](
250 | atten_decoder[i][j][0]
251 | )
252 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](
253 | torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1)
254 | )
255 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]
256 | else:
257 | atten_decoder[i][j][0] = F.interpolate(
258 | atten_decoder[i][j - 1][2],
259 | scale_factor=2,
260 | mode="bilinear",
261 | align_corners=True,
262 | )
263 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](
264 | atten_decoder[i][j][0]
265 | )
266 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](
267 | torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1)
268 | )
269 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]
270 |
271 | # define task prediction layers
272 | t1_pred = F.log_softmax(self.pred_task1(atten_decoder[0][-1][-1]), dim=1)
273 | t2_pred = self.pred_task2(atten_decoder[1][-1][-1])
274 | t3_pred = self.pred_task3(atten_decoder[2][-1][-1])
275 | t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)
276 |
277 | return (
278 | [t1_pred, t2_pred, t3_pred],
279 | (
280 | atten_decoder[0][-1][-1],
281 | atten_decoder[1][-1][-1],
282 | atten_decoder[2][-1][-1],
283 | ),
284 | )
285 |
286 |
287 | class SegNetMtan(nn.Module):
288 | def __init__(self):
289 | super().__init__()
290 | self.segnet = _SegNet()
291 |
292 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
293 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n)
294 |
295 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]:
296 | return (p for n, p in self.segnet.named_parameters() if "pred" in n)
297 |
298 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
299 | """Parameters of the last shared layer.
300 | Returns
301 | -------
302 | """
303 | return []
304 |
305 | def forward(self, x, return_representation=False):
306 | if return_representation:
307 | return self.segnet(x)
308 | else:
309 | pred, rep = self.segnet(x)
310 | return pred
311 |
312 |
313 | class SegNetSplit(nn.Module):
314 | def __init__(self, model_type="standard"):
315 | super(SegNetSplit, self).__init__()
316 | # initialise network parameters
317 | assert model_type in ["standard", "wide", "deep"]
318 | self.model_type = model_type
319 | if self.model_type == "wide":
320 | filter = [64, 128, 256, 512, 1024]
321 | else:
322 | filter = [64, 128, 256, 512, 512]
323 |
324 | self.class_nb = 13
325 |
326 | # define encoder decoder layers
327 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
328 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
329 | for i in range(4):
330 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
331 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))
332 |
333 | # define convolution layer
334 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
335 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
336 | for i in range(4):
337 | if i == 0:
338 | self.conv_block_enc.append(
339 | self.conv_layer([filter[i + 1], filter[i + 1]])
340 | )
341 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
342 | else:
343 | self.conv_block_enc.append(
344 | nn.Sequential(
345 | self.conv_layer([filter[i + 1], filter[i + 1]]),
346 | self.conv_layer([filter[i + 1], filter[i + 1]]),
347 | )
348 | )
349 | self.conv_block_dec.append(
350 | nn.Sequential(
351 | self.conv_layer([filter[i], filter[i]]),
352 | self.conv_layer([filter[i], filter[i]]),
353 | )
354 | )
355 |
356 | # define task specific layers
357 | self.pred_task1 = nn.Sequential(
358 | nn.Conv2d(
359 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
360 | ),
361 | nn.Conv2d(
362 | in_channels=filter[0],
363 | out_channels=self.class_nb,
364 | kernel_size=1,
365 | padding=0,
366 | ),
367 | )
368 | self.pred_task2 = nn.Sequential(
369 | nn.Conv2d(
370 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
371 | ),
372 | nn.Conv2d(in_channels=filter[0], out_channels=1, kernel_size=1, padding=0),
373 | )
374 | self.pred_task3 = nn.Sequential(
375 | nn.Conv2d(
376 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
377 | ),
378 | nn.Conv2d(in_channels=filter[0], out_channels=3, kernel_size=1, padding=0),
379 | )
380 |
381 | # define pooling and unpooling functions
382 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
383 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)
384 |
385 | for m in self.modules():
386 | if isinstance(m, nn.Conv2d):
387 | nn.init.xavier_normal_(m.weight)
388 | nn.init.constant_(m.bias, 0)
389 | elif isinstance(m, nn.BatchNorm2d):
390 | nn.init.constant_(m.weight, 1)
391 | nn.init.constant_(m.bias, 0)
392 | elif isinstance(m, nn.Linear):
393 | nn.init.xavier_normal_(m.weight)
394 | nn.init.constant_(m.bias, 0)
395 |
396 | # define convolutional block
397 | def conv_layer(self, channel):
398 | if self.model_type == "deep":
399 | conv_block = nn.Sequential(
400 | nn.Conv2d(
401 | in_channels=channel[0],
402 | out_channels=channel[1],
403 | kernel_size=3,
404 | padding=1,
405 | ),
406 | nn.BatchNorm2d(num_features=channel[1]),
407 | nn.ReLU(inplace=True),
408 | nn.Conv2d(
409 | in_channels=channel[1],
410 | out_channels=channel[1],
411 | kernel_size=3,
412 | padding=1,
413 | ),
414 | nn.BatchNorm2d(num_features=channel[1]),
415 | nn.ReLU(inplace=True),
416 | )
417 | else:
418 | conv_block = nn.Sequential(
419 | nn.Conv2d(
420 | in_channels=channel[0],
421 | out_channels=channel[1],
422 | kernel_size=3,
423 | padding=1,
424 | ),
425 | nn.BatchNorm2d(num_features=channel[1]),
426 | nn.ReLU(inplace=True),
427 | )
428 | return conv_block
429 |
430 | def forward(self, x):
431 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = (
432 | [0] * 5 for _ in range(5)
433 | )
434 | for i in range(5):
435 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))
436 |
437 | # global shared encoder-decoder network
438 | for i in range(5):
439 | if i == 0:
440 | g_encoder[i][0] = self.encoder_block[i](x)
441 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
442 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
443 | else:
444 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
445 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
446 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
447 |
448 | for i in range(5):
449 | if i == 0:
450 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
451 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
452 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
453 | else:
454 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
455 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
456 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
457 |
458 |         # task prediction layers; g_decoder[i][1] (with i == 4 after the loop) is the final shared feature map
459 | t1_pred = F.log_softmax(self.pred_task1(g_decoder[i][1]), dim=1)
460 | t2_pred = self.pred_task2(g_decoder[i][1])
461 | t3_pred = self.pred_task3(g_decoder[i][1])
462 | t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)
463 |
464 |         return [t1_pred, t2_pred, t3_pred], g_decoder[i][1]  # NOTE: the second return value is the shared representation
467 |
468 |
469 | class SegNet(nn.Module):
470 | def __init__(self):
471 | super().__init__()
472 | self.segnet = SegNetSplit()
473 |
474 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
475 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n)
476 |
477 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]:
478 | return (p for n, p in self.segnet.named_parameters() if "pred" in n)
479 |
480 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
481 | """Parameters of the last shared layer.
482 | Returns
483 | -------
484 | """
485 | return self.segnet.conv_block_dec[-5].parameters()
486 |
487 | def forward(self, x, return_representation=False):
488 | if return_representation:
489 | return self.segnet(x)
490 | else:
491 | pred, rep = self.segnet(x)
492 | return pred
493 |
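# A minimal sanity-check sketch (illustrative, not part of the training code):
# the shared/task-specific split above is purely name-based, so the two
# generators partition the underlying SegNet parameters exactly.
def _demo_parameter_split():
    model = SegNet()
    n_shared = sum(p.numel() for p in model.shared_parameters())
    n_task_specific = sum(p.numel() for p in model.task_specific_parameters())
    assert n_shared + n_task_specific == sum(
        p.numel() for p in model.segnet.parameters()
    )
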
--------------------------------------------------------------------------------
/methods/weight_methods.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | from abc import abstractmethod
4 | from typing import Dict, List, Tuple, Union
5 |
6 | import cvxpy as cp
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from scipy.optimize import minimize
11 |
12 | from methods.min_norm_solvers import MinNormSolver, gradient_normalizers
13 |
14 |
15 | class WeightMethod:
16 | def __init__(self, n_tasks: int, device: torch.device):
17 | super().__init__()
18 | self.n_tasks = n_tasks
19 | self.device = device
20 |
21 | @abstractmethod
22 | def get_weighted_loss(
23 | self,
24 | losses: torch.Tensor,
25 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
26 | task_specific_parameters: Union[
27 | List[torch.nn.parameter.Parameter], torch.Tensor
28 | ],
29 | last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
30 | representation: Union[torch.nn.parameter.Parameter, torch.Tensor],
31 | **kwargs,
32 | ):
33 | pass
34 |
35 | def backward(
36 | self,
37 | losses: torch.Tensor,
38 | shared_parameters: Union[
39 | List[torch.nn.parameter.Parameter], torch.Tensor
40 | ] = None,
41 | task_specific_parameters: Union[
42 | List[torch.nn.parameter.Parameter], torch.Tensor
43 | ] = None,
44 | last_shared_parameters: Union[
45 | List[torch.nn.parameter.Parameter], torch.Tensor
46 | ] = None,
47 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
48 | **kwargs,
49 | ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]:
50 | """
51 |
52 | Parameters
53 | ----------
54 | losses :
55 | shared_parameters :
56 | task_specific_parameters :
57 | last_shared_parameters : parameters of last shared layer/block
58 | representation : shared representation
59 | kwargs :
60 |
61 | Returns
62 | -------
63 | Loss, extra outputs
64 | """
65 | loss, extra_outputs = self.get_weighted_loss(
66 | losses=losses,
67 | shared_parameters=shared_parameters,
68 | task_specific_parameters=task_specific_parameters,
69 | last_shared_parameters=last_shared_parameters,
70 | representation=representation,
71 | **kwargs,
72 | )
73 | loss.backward()
74 | return loss, extra_outputs
75 |
76 | def __call__(
77 | self,
78 | losses: torch.Tensor,
79 | shared_parameters: Union[
80 | List[torch.nn.parameter.Parameter], torch.Tensor
81 | ] = None,
82 | task_specific_parameters: Union[
83 | List[torch.nn.parameter.Parameter], torch.Tensor
84 | ] = None,
85 | **kwargs,
86 | ):
87 | return self.backward(
88 | losses=losses,
89 | shared_parameters=shared_parameters,
90 | task_specific_parameters=task_specific_parameters,
91 | **kwargs,
92 | )
93 |
94 | def parameters(self) -> List[torch.Tensor]:
95 | """return learnable parameters"""
96 | return []
97 |
98 |
99 | class NashMTL(WeightMethod):
100 | def __init__(
101 | self,
102 | n_tasks: int,
103 | device: torch.device,
104 | max_norm: float = 1.0,
105 | update_weights_every: int = 1,
106 | optim_niter=20,
107 | ):
108 | super(NashMTL, self).__init__(
109 | n_tasks=n_tasks,
110 | device=device,
111 | )
112 |
113 | self.optim_niter = optim_niter
114 | self.update_weights_every = update_weights_every
115 | self.max_norm = max_norm
116 |
117 | self.prvs_alpha_param = None
118 | self.normalization_factor = np.ones((1,))
119 |         self.init_gtg = np.eye(self.n_tasks)
120 |         self.step = 0
121 | self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)
122 |
123 | def _stop_criteria(self, gtg, alpha_t):
124 | return (
125 | (self.alpha_param.value is None)
126 | or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3)
127 | or (
128 | np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value)
129 | < 1e-6
130 | )
131 | )
132 |
133 |     def solve_optimization(self, gtg: np.ndarray):
134 | self.G_param.value = gtg
135 | self.normalization_factor_param.value = self.normalization_factor
136 |
137 | alpha_t = self.prvs_alpha
138 | for _ in range(self.optim_niter):
139 | self.alpha_param.value = alpha_t
140 | self.prvs_alpha_param.value = alpha_t
141 |
142 | try:
143 | self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100)
144 |             except Exception:  # solver failure: fall back to the previous iterate
145 | self.alpha_param.value = self.prvs_alpha_param.value
146 |
147 | if self._stop_criteria(gtg, alpha_t):
148 | break
149 |
150 | alpha_t = self.alpha_param.value
151 |
152 | if alpha_t is not None:
153 | self.prvs_alpha = alpha_t
154 |
155 | return self.prvs_alpha
156 |
157 | def _calc_phi_alpha_linearization(self):
158 | G_prvs_alpha = self.G_param @ self.prvs_alpha_param
159 | prvs_phi_tag = 1 / self.prvs_alpha_param + (1 / G_prvs_alpha) @ self.G_param
160 | phi_alpha = prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param)
161 | return phi_alpha
162 |
163 | def _init_optim_problem(self):
164 | self.alpha_param = cp.Variable(shape=(self.n_tasks,), nonneg=True)
165 | self.prvs_alpha_param = cp.Parameter(
166 | shape=(self.n_tasks,), value=self.prvs_alpha
167 | )
168 | self.G_param = cp.Parameter(
169 | shape=(self.n_tasks, self.n_tasks), value=self.init_gtg
170 | )
171 | self.normalization_factor_param = cp.Parameter(
172 | shape=(1,), value=np.array([1.0])
173 | )
174 |
175 | self.phi_alpha = self._calc_phi_alpha_linearization()
176 |
177 | G_alpha = self.G_param @ self.alpha_param
178 | constraint = []
179 | for i in range(self.n_tasks):
180 | constraint.append(
181 | -cp.log(self.alpha_param[i] * self.normalization_factor_param)
182 | - cp.log(G_alpha[i])
183 | <= 0
184 | )
185 | obj = cp.Minimize(
186 | cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param
187 | )
188 | self.prob = cp.Problem(obj, constraint)
189 |
190 | def get_weighted_loss(
191 | self,
192 | losses,
193 | shared_parameters,
194 | **kwargs,
195 | ):
196 | """
197 |
198 | Parameters
199 | ----------
200 | losses :
201 | shared_parameters : shared parameters
202 | kwargs :
203 |
204 | Returns
205 | -------
206 |
207 | """
208 |
209 | extra_outputs = dict()
210 | if self.step == 0:
211 | self._init_optim_problem()
212 |
213 | if (self.step % self.update_weights_every) == 0:
214 | self.step += 1
215 |
216 | grads = {}
217 | for i, loss in enumerate(losses):
218 | g = list(
219 | torch.autograd.grad(
220 | loss,
221 | shared_parameters,
222 | retain_graph=True,
223 | )
224 | )
225 | grad = torch.cat([torch.flatten(grad) for grad in g])
226 | grads[i] = grad
227 |
228 | G = torch.stack(tuple(v for v in grads.values()))
229 | GTG = torch.mm(G, G.t())
230 |
231 | self.normalization_factor = (
232 | torch.norm(GTG).detach().cpu().numpy().reshape((1,))
233 | )
234 | GTG = GTG / self.normalization_factor.item()
235 | alpha = self.solve_optimization(GTG.cpu().detach().numpy())
236 | alpha = torch.from_numpy(alpha)
237 |
238 | else:
239 | self.step += 1
240 | alpha = self.prvs_alpha
241 |
242 | weighted_loss = sum([losses[i] * alpha[i] for i in range(len(alpha))])
243 | extra_outputs["weights"] = alpha
244 | return weighted_loss, extra_outputs
245 |
246 | def backward(
247 | self,
248 | losses: torch.Tensor,
249 | shared_parameters: Union[
250 | List[torch.nn.parameter.Parameter], torch.Tensor
251 | ] = None,
252 | task_specific_parameters: Union[
253 | List[torch.nn.parameter.Parameter], torch.Tensor
254 | ] = None,
255 | last_shared_parameters: Union[
256 | List[torch.nn.parameter.Parameter], torch.Tensor
257 | ] = None,
258 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
259 | **kwargs,
260 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
261 | loss, extra_outputs = self.get_weighted_loss(
262 | losses=losses,
263 | shared_parameters=shared_parameters,
264 | **kwargs,
265 | )
266 | loss.backward()
267 |
268 |         # clip the shared-parameter gradients so their norm is at most self.max_norm
269 | if self.max_norm > 0:
270 | torch.nn.utils.clip_grad_norm_(shared_parameters, self.max_norm)
271 |
272 | return loss, extra_outputs
273 |
274 |
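# A minimal solver sketch with made-up numbers (illustrative only): the fixed
# point returned by `solve_optimization` approximately satisfies the Nash
# bargaining stationarity condition (G G^T) alpha = 1 / alpha.
def _demo_nashmtl_solver():
    method = NashMTL(n_tasks=2, device=torch.device("cpu"))
    method._init_optim_problem()
    gtg = np.array([[1.0, 0.2], [0.2, 0.5]])  # toy Gram matrix G G^T
    alpha = method.solve_optimization(gtg)
    print(gtg @ alpha, 1.0 / alpha)  # the two sides should roughly agree
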
275 | class LinearScalarization(WeightMethod):
276 | """Linear scalarization baseline L = sum_j w_j * l_j where l_j is the loss for task j and w_h"""
277 |
278 | def __init__(
279 | self,
280 | n_tasks: int,
281 | device: torch.device,
282 | task_weights: Union[List[float], torch.Tensor] = None,
283 | ):
284 | super().__init__(n_tasks, device=device)
285 | if task_weights is None:
286 | task_weights = torch.ones((n_tasks,))
287 | if not isinstance(task_weights, torch.Tensor):
288 | task_weights = torch.tensor(task_weights)
289 | assert len(task_weights) == n_tasks
290 | self.task_weights = task_weights.to(device)
291 |
292 | def get_weighted_loss(self, losses, **kwargs):
293 | loss = torch.sum(losses * self.task_weights)
294 | return loss, dict(weights=self.task_weights)
295 |
296 |
297 | class ScaleInvariantLinearScalarization(WeightMethod):
298 | """Linear scalarization baseline L = sum_j w_j * l_j where l_j is the loss for task j and w_h"""
299 |
300 | def __init__(
301 | self,
302 | n_tasks: int,
303 | device: torch.device,
304 | task_weights: Union[List[float], torch.Tensor] = None,
305 | ):
306 | super().__init__(n_tasks, device=device)
307 | if task_weights is None:
308 | task_weights = torch.ones((n_tasks,))
309 | if not isinstance(task_weights, torch.Tensor):
310 | task_weights = torch.tensor(task_weights)
311 | assert len(task_weights) == n_tasks
312 | self.task_weights = task_weights.to(device)
313 |
314 | def get_weighted_loss(self, losses, **kwargs):
315 | loss = torch.sum(torch.log(losses) * self.task_weights)
316 | return loss, dict(weights=self.task_weights)
317 |
318 |
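# A small illustrative check: under the log transform, rescaling a task loss
# by a positive constant only shifts the objective by an additive constant,
# so the parameter gradient is unchanged -- hence "scale invariant".
def _demo_scale_invariance():
    w = torch.tensor([2.0], requires_grad=True)
    method = ScaleInvariantLinearScalarization(n_tasks=2, device=torch.device("cpu"))
    losses = torch.stack(((w ** 2).sum(), ((w - 1) ** 2).sum()))
    (g,) = torch.autograd.grad(method.get_weighted_loss(losses)[0], w)
    scaled = torch.stack((10.0 * (w ** 2).sum(), ((w - 1) ** 2).sum()))
    (g_scaled,) = torch.autograd.grad(method.get_weighted_loss(scaled)[0], w)
    assert torch.allclose(g, g_scaled)
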
319 | class MGDA(WeightMethod):
320 | """Based on the official implementation of: Multi-Task Learning as Multi-Objective Optimization
321 | Ozan Sener, Vladlen Koltun
322 | Neural Information Processing Systems (NeurIPS) 2018
323 | https://github.com/intel-isl/MultiObjectiveOptimization
324 |
325 | """
326 |
327 | def __init__(
328 | self, n_tasks, device: torch.device, params="shared", normalization="none"
329 | ):
330 | super().__init__(n_tasks, device=device)
331 | self.solver = MinNormSolver()
332 | assert params in ["shared", "last", "rep"]
333 | self.params = params
334 | assert normalization in ["norm", "loss", "loss+", "none"]
335 | self.normalization = normalization
336 |
337 | @staticmethod
338 | def _flattening(grad):
339 | return torch.cat(
340 | tuple(
341 | g.reshape(
342 | -1,
343 | )
344 | for i, g in enumerate(grad)
345 | ),
346 | dim=0,
347 | )
348 |
349 | def get_weighted_loss(
350 | self,
351 | losses,
352 | shared_parameters=None,
353 | last_shared_parameters=None,
354 | representation=None,
355 | **kwargs,
356 | ):
357 | """
358 |
359 | Parameters
360 | ----------
361 | losses :
362 | shared_parameters :
363 | last_shared_parameters :
364 | representation :
365 | kwargs :
366 |
367 | Returns
368 | -------
369 |
370 | """
371 |         # compute per-task gradients w.r.t. the selected parameter group
372 | grads = {}
373 | params = dict(
374 | rep=representation, shared=shared_parameters, last=last_shared_parameters
375 | )[self.params]
376 | for i, loss in enumerate(losses):
377 | g = list(
378 | torch.autograd.grad(
379 | loss,
380 | params,
381 | retain_graph=True,
382 | )
383 | )
384 |             # Normalize all gradients; this is optional and not included in the paper.
385 |
386 | grads[i] = [torch.flatten(grad) for grad in g]
387 |
388 | gn = gradient_normalizers(grads, losses, self.normalization)
389 | for t in range(self.n_tasks):
390 | for gr_i in range(len(grads[t])):
391 | grads[t][gr_i] = grads[t][gr_i] / gn[t]
392 |
393 | sol, min_norm = self.solver.find_min_norm_element(
394 | [grads[t] for t in range(len(grads))]
395 | )
396 | sol = sol * self.n_tasks # make sure it sums to self.n_tasks
397 | weighted_loss = sum([losses[i] * sol[i] for i in range(len(sol))])
398 |
399 | return weighted_loss, dict(weights=torch.from_numpy(sol.astype(np.float32)))
400 |
401 |
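# An illustrative two-task special case: the min-norm problem solved by
# `find_min_norm_element` has the closed form
# gamma* = clip(((g2 - g1) @ g2) / ||g1 - g2||^2, 0, 1),
# and gamma*g1 + (1 - gamma*)g2 is a common descent direction.
def _demo_min_norm_two_tasks():
    g1 = torch.tensor([1.0, 0.0])
    g2 = torch.tensor([0.0, 1.0])
    gamma = torch.clamp((g2 - g1) @ g2 / (g1 - g2).pow(2).sum(), 0.0, 1.0)
    d = gamma * g1 + (1.0 - gamma) * g2
    assert torch.allclose(d, torch.tensor([0.5, 0.5]))
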
402 | class STL(WeightMethod):
403 | """Single task learning"""
404 |
405 | def __init__(self, n_tasks, device: torch.device, main_task):
406 | super().__init__(n_tasks, device=device)
407 | self.main_task = main_task
408 | self.weights = torch.zeros(n_tasks, device=device)
409 | self.weights[main_task] = 1.0
410 |
411 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
412 | assert len(losses) == self.n_tasks
413 | loss = losses[self.main_task]
414 |
415 | return loss, dict(weights=self.weights)
416 |
417 |
418 | class Uncertainty(WeightMethod):
419 | """Implementation of `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics`
420 | Source: https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynb
421 | """
422 |
423 | def __init__(self, n_tasks, device: torch.device):
424 | super().__init__(n_tasks, device=device)
425 | self.logsigma = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True)
426 |
427 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
428 | loss = sum(
429 | [
430 | 0.5 * (torch.exp(-logs) * loss + logs)
431 | for loss, logs in zip(losses, self.logsigma)
432 | ]
433 | )
434 |
435 | return loss, dict(
436 | weights=torch.exp(-self.logsigma)
437 | ) # NOTE: not exactly task weights
438 |
439 | def parameters(self) -> List[torch.Tensor]:
440 | return [self.logsigma]
441 |
442 |
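# An illustrative fixed-point check: each term 0.5 * (exp(-s_j) * l_j + s_j)
# is minimized at s_j = log(l_j), so at the per-task optimum the implied
# weight exp(-s_j) = 1 / l_j down-weights high-loss (high-uncertainty) tasks.
def _demo_uncertainty_fixed_point():
    losses = torch.tensor([4.0, 0.25])
    logsigma_opt = torch.log(losses)  # stationary point of each term
    assert torch.allclose(torch.exp(-logsigma_opt), 1.0 / losses)
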
443 | class PCGrad(WeightMethod):
444 | """Modification of: https://github.com/WeiChengTseng/Pytorch-PCGrad/blob/master/pcgrad.py
445 |
446 | @misc{Pytorch-PCGrad,
447 | author = {Wei-Cheng Tseng},
448 | title = {WeiChengTseng/Pytorch-PCGrad},
449 | url = {https://github.com/WeiChengTseng/Pytorch-PCGrad.git},
450 | year = {2020}
451 | }
452 |
453 | """
454 |
455 | def __init__(self, n_tasks: int, device: torch.device, reduction="sum"):
456 | super().__init__(n_tasks, device=device)
457 | assert reduction in ["mean", "sum"]
458 | self.reduction = reduction
459 |
460 | def get_weighted_loss(
461 | self,
462 | losses: torch.Tensor,
463 | shared_parameters: Union[
464 | List[torch.nn.parameter.Parameter], torch.Tensor
465 | ] = None,
466 | task_specific_parameters: Union[
467 | List[torch.nn.parameter.Parameter], torch.Tensor
468 | ] = None,
469 | **kwargs,
470 | ):
471 | raise NotImplementedError
472 |
473 | def _set_pc_grads(self, losses, shared_parameters, task_specific_parameters=None):
474 | # shared part
475 | shared_grads = []
476 | for l in losses:
477 | shared_grads.append(
478 | torch.autograd.grad(l, shared_parameters, retain_graph=True)
479 | )
480 |
481 | if isinstance(shared_parameters, torch.Tensor):
482 | shared_parameters = [shared_parameters]
483 | non_conflict_shared_grads = self._project_conflicting(shared_grads)
484 | for p, g in zip(shared_parameters, non_conflict_shared_grads):
485 | p.grad = g
486 |
487 | # task specific part
488 | if task_specific_parameters is not None:
489 | task_specific_grads = torch.autograd.grad(
490 | losses.sum(), task_specific_parameters
491 | )
492 | if isinstance(task_specific_parameters, torch.Tensor):
493 | task_specific_parameters = [task_specific_parameters]
494 | for p, g in zip(task_specific_parameters, task_specific_grads):
495 | p.grad = g
496 |
497 | def _project_conflicting(self, grads: List[Tuple[torch.Tensor]]):
498 | pc_grad = copy.deepcopy(grads)
499 | for g_i in pc_grad:
500 | random.shuffle(grads)
501 | for g_j in grads:
502 | g_i_g_j = sum(
503 | [
504 | torch.dot(torch.flatten(grad_i), torch.flatten(grad_j))
505 | for grad_i, grad_j in zip(g_i, g_j)
506 | ]
507 | )
508 | if g_i_g_j < 0:
509 | g_j_norm_square = (
510 | torch.norm(torch.cat([torch.flatten(g) for g in g_j])) ** 2
511 | )
512 | for grad_i, grad_j in zip(g_i, g_j):
513 | grad_i -= g_i_g_j * grad_j / g_j_norm_square
514 |
515 | merged_grad = [sum(g) for g in zip(*pc_grad)]
516 | if self.reduction == "mean":
517 | merged_grad = [g / self.n_tasks for g in merged_grad]
518 |
519 | return merged_grad
520 |
521 | def backward(
522 | self,
523 | losses: torch.Tensor,
524 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
525 | shared_parameters: Union[
526 | List[torch.nn.parameter.Parameter], torch.Tensor
527 | ] = None,
528 | task_specific_parameters: Union[
529 | List[torch.nn.parameter.Parameter], torch.Tensor
530 | ] = None,
531 | **kwargs,
532 | ):
533 | self._set_pc_grads(losses, shared_parameters, task_specific_parameters)
534 | return None, {} # NOTE: to align with all other weight methods
535 |
536 |
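# An illustrative projection step with toy gradients: when two task gradients
# conflict (negative inner product), PCGrad removes from one the component
# along the other, leaving a remainder orthogonal to the conflicting gradient.
def _demo_pcgrad_projection():
    g1 = torch.tensor([1.0, -1.0])
    g2 = torch.tensor([0.0, 1.0])
    assert g1 @ g2 < 0  # conflicting pair
    g1_projected = g1 - (g1 @ g2) / g2.pow(2).sum() * g2
    assert torch.allclose(g1_projected, torch.tensor([1.0, 0.0]))
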
537 | class CAGrad(WeightMethod):
538 | def __init__(self, n_tasks, device: torch.device, c=0.4):
539 | super().__init__(n_tasks, device=device)
540 | self.c = c
541 |
542 | def get_weighted_loss(
543 | self,
544 | losses,
545 | shared_parameters,
546 | **kwargs,
547 | ):
548 | """
549 | Parameters
550 | ----------
551 | losses :
552 | shared_parameters : shared parameters
553 | kwargs :
554 | Returns
555 | -------
556 | """
557 |         # NOTE: only shared parameters are supported for now; see the paper for other variants.
558 | grad_dims = []
559 | for param in shared_parameters:
560 | grad_dims.append(param.data.numel())
561 | grads = torch.Tensor(sum(grad_dims), self.n_tasks).to(self.device)
562 |
563 | for i in range(self.n_tasks):
564 | if i < (self.n_tasks - 1):
565 | losses[i].backward(retain_graph=True)
566 | else:
567 | losses[i].backward()
568 | self.grad2vec(shared_parameters, grads, grad_dims, i)
569 | # multi_task_model.zero_grad_shared_modules()
570 | for p in shared_parameters:
571 | p.grad = None
572 |
573 | g = self.cagrad(grads, alpha=self.c, rescale=1)
574 | self.overwrite_grad(shared_parameters, g, grad_dims)
575 |
576 | def cagrad(self, grads, alpha=0.5, rescale=1):
577 | GG = grads.t().mm(grads).cpu() # [num_tasks, num_tasks]
578 | g0_norm = (GG.mean() + 1e-8).sqrt() # norm of the average gradient
579 |
580 | x_start = np.ones(self.n_tasks) / self.n_tasks
581 | bnds = tuple((0, 1) for x in x_start)
582 | cons = {"type": "eq", "fun": lambda x: 1 - sum(x)}
583 | A = GG.numpy()
584 | b = x_start.copy()
585 | c = (alpha * g0_norm + 1e-8).item()
586 |
587 | def objfn(x):
588 | return (
589 | x.reshape(1, self.n_tasks).dot(A).dot(b.reshape(self.n_tasks, 1))
590 | + c
591 | * np.sqrt(
592 | x.reshape(1, self.n_tasks).dot(A).dot(x.reshape(self.n_tasks, 1))
593 | + 1e-8
594 | )
595 | ).sum()
596 |
597 | res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
598 | w_cpu = res.x
599 | ww = torch.Tensor(w_cpu).to(grads.device)
600 | gw = (grads * ww.view(1, -1)).sum(1)
601 | gw_norm = gw.norm()
602 | lmbda = c / (gw_norm + 1e-8)
603 | g = grads.mean(1) + lmbda * gw
604 | if rescale == 0:
605 | return g
606 | elif rescale == 1:
607 | return g / (1 + alpha ** 2)
608 | else:
609 | return g / (1 + alpha)
610 |
611 | @staticmethod
612 | def grad2vec(shared_params, grads, grad_dims, task):
613 | # store the gradients
614 | grads[:, task].fill_(0.0)
615 | cnt = 0
616 | # for mm in m.shared_modules():
617 | # for p in mm.parameters():
618 |
619 | for param in shared_params:
620 | grad = param.grad
621 | if grad is not None:
622 | grad_cur = grad.data.detach().clone()
623 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
624 | en = sum(grad_dims[: cnt + 1])
625 | grads[beg:en, task].copy_(grad_cur.data.view(-1))
626 | cnt += 1
627 |
628 | def overwrite_grad(self, shared_parameters, newgrad, grad_dims):
629 | newgrad = newgrad * self.n_tasks # to match the sum loss
630 | cnt = 0
631 |
632 | # for mm in m.shared_modules():
633 | # for param in mm.parameters():
634 | for param in shared_parameters:
635 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
636 | en = sum(grad_dims[: cnt + 1])
637 | this_grad = newgrad[beg:en].contiguous().view(param.data.size())
638 | param.grad = this_grad.data.clone()
639 | cnt += 1
640 |
641 | def backward(
642 | self,
643 | losses: torch.Tensor,
644 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
645 | shared_parameters: Union[
646 | List[torch.nn.parameter.Parameter], torch.Tensor
647 | ] = None,
648 | task_specific_parameters: Union[
649 | List[torch.nn.parameter.Parameter], torch.Tensor
650 | ] = None,
651 | **kwargs,
652 | ):
653 | self.get_weighted_loss(losses, shared_parameters)
654 | return None, {} # NOTE: to align with all other weight methods
655 |
656 |
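# An illustrative call with toy gradients (one column per task): `cagrad`
# perturbs the average gradient toward the worst-off task while staying within
# a ball of radius c * ||g_0|| around the average; as c -> 0 it approaches
# plain gradient averaging.
def _demo_cagrad_direction():
    method = CAGrad(n_tasks=2, device=torch.device("cpu"), c=0.4)
    grads = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # shape [dim, n_tasks]
    g = method.cagrad(grads, alpha=method.c, rescale=1)
    print(g)
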
657 | class RLW(WeightMethod):
658 | """Random loss weighting: https://arxiv.org/pdf/2111.10603.pdf"""
659 |
660 | def __init__(self, n_tasks, device: torch.device):
661 | super().__init__(n_tasks, device=device)
662 |
663 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
664 | assert len(losses) == self.n_tasks
665 | weight = (F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device)
666 | loss = torch.sum(losses * weight)
667 |
668 | return loss, dict(weights=weight)
669 |
670 |
671 | class IMTLG(WeightMethod):
672 | """TOWARDS IMPARTIAL MULTI-TASK LEARNING: https://openreview.net/pdf?id=IMPnRXEWpvr"""
673 |
674 | def __init__(self, n_tasks, device: torch.device):
675 | super().__init__(n_tasks, device=device)
676 |
677 | def get_weighted_loss(
678 | self,
679 | losses,
680 | shared_parameters,
681 | **kwargs,
682 | ):
683 | grads = {}
684 | norm_grads = {}
685 |
686 | for i, loss in enumerate(losses):
687 | g = list(
688 | torch.autograd.grad(
689 | loss,
690 | shared_parameters,
691 | retain_graph=True,
692 | )
693 | )
694 | grad = torch.cat([torch.flatten(grad) for grad in g])
695 | norm_term = torch.norm(grad)
696 |
697 | grads[i] = grad
698 | norm_grads[i] = grad / norm_term
699 |
700 | G = torch.stack(tuple(v for v in grads.values()))
701 | D = (
702 | G[
703 | 0,
704 | ]
705 | - G[
706 | 1:,
707 | ]
708 | )
709 |
710 | U = torch.stack(tuple(v for v in norm_grads.values()))
711 | U = (
712 | U[
713 | 0,
714 | ]
715 | - U[
716 | 1:,
717 | ]
718 | )
719 | first_element = torch.matmul(
720 | G[
721 | 0,
722 | ],
723 | U.t(),
724 | )
725 | try:
726 | second_element = torch.inverse(torch.matmul(D, U.t()))
727 |         except RuntimeError:
728 | # workaround for cases where matrix is singular
729 | second_element = torch.inverse(
730 | torch.eye(self.n_tasks - 1, device=self.device) * 1e-8
731 | + torch.matmul(D, U.t())
732 | )
733 |
734 | alpha_ = torch.matmul(first_element, second_element)
735 | alpha = torch.cat(
736 | (torch.tensor(1 - alpha_.sum(), device=self.device).unsqueeze(-1), alpha_)
737 | )
738 |
739 | loss = torch.sum(losses * alpha)
740 |
741 | return loss, dict(weights=alpha)
742 |
743 |
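# An illustrative property check on random losses: IMTL-G chooses alpha such
# that the aggregated gradient has the same projection onto every unit-norm
# task gradient (the "impartial" condition).
def _demo_imtl_balanced_projections():
    torch.manual_seed(0)
    w = torch.randn(6, requires_grad=True)
    losses = torch.stack(((w ** 2).sum(), ((w - 1) ** 2).sum(), (2 * w).sum()))
    method = IMTLG(n_tasks=3, device=torch.device("cpu"))
    _, out = method.get_weighted_loss(losses, shared_parameters=[w])
    grads = [torch.autograd.grad(l, w, retain_graph=True)[0] for l in losses]
    g = sum(a * gi for a, gi in zip(out["weights"], grads))
    projections = torch.stack([g @ (gi / gi.norm()) for gi in grads])
    assert torch.allclose(projections, projections[0].expand(3), atol=1e-4)
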
744 | class DynamicWeightAverage(WeightMethod):
745 | """Dynamic Weight Average from `End-to-End Multi-Task Learning with Attention`.
746 | Modification of: https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_segnet_split.py#L242
747 | """
748 |
749 | def __init__(
750 | self, n_tasks, device: torch.device, iteration_window: int = 25, temp=2.0
751 | ):
752 | """
753 |
754 | Parameters
755 | ----------
756 | n_tasks :
757 | iteration_window : 'iteration' loss is averaged over the last 'iteration_window' losses
758 | temp :
759 | """
760 | super().__init__(n_tasks, device=device)
761 | self.iteration_window = iteration_window
762 | self.temp = temp
763 | self.running_iterations = 0
764 | self.costs = np.ones((iteration_window * 2, n_tasks), dtype=np.float32)
765 | self.weights = np.ones(n_tasks, dtype=np.float32)
766 |
767 | def get_weighted_loss(self, losses, **kwargs):
768 |
769 | cost = losses.detach().cpu().numpy()
770 |
771 | # update costs - fifo
772 | self.costs[:-1, :] = self.costs[1:, :]
773 | self.costs[-1, :] = cost
774 |
775 | if self.running_iterations > self.iteration_window:
776 | ws = self.costs[self.iteration_window :, :].mean(0) / self.costs[
777 | : self.iteration_window, :
778 | ].mean(0)
779 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (
780 | np.exp(ws / self.temp)
781 | ).sum()
782 |
783 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(
784 | losses.device
785 | )
786 | loss = (task_weights * losses).mean()
787 |
788 | self.running_iterations += 1
789 |
790 | return loss, dict(weights=task_weights)
791 |
792 |
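# An illustrative run on synthetic losses: a task whose loss keeps shrinking
# gets a descent-rate ratio below one and therefore a smaller softmax weight
# than a task whose loss has stalled.
def _demo_dwa_weights():
    method = DynamicWeightAverage(n_tasks=2, device=torch.device("cpu"), iteration_window=5)
    for step in range(50):
        losses = torch.tensor([1.0 / (1.0 + step), 1.0])  # task 0 improves, task 1 stalls
        _, out = method.get_weighted_loss(losses)
    assert out["weights"][0] < out["weights"][1]
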
793 | class WeightMethods:
794 | def __init__(self, method: str, n_tasks: int, device: torch.device, **kwargs):
795 | """
796 |         :param method: weighting method name; must be a key of the METHODS dict below.
797 | """
798 | assert method in list(METHODS.keys()), f"unknown method {method}."
799 |
800 | self.method = METHODS[method](n_tasks=n_tasks, device=device, **kwargs)
801 |
802 | def get_weighted_loss(self, losses, **kwargs):
803 | return self.method.get_weighted_loss(losses, **kwargs)
804 |
805 | def backward(
806 | self, losses, **kwargs
807 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
808 | return self.method.backward(losses, **kwargs)
809 |
810 |     def __call__(self, losses, **kwargs):
811 | return self.backward(losses, **kwargs)
812 |
813 | def parameters(self):
814 | return self.method.parameters()
815 |
816 |
817 | METHODS = dict(
818 | stl=STL,
819 | ls=LinearScalarization,
820 | uw=Uncertainty,
821 | pcgrad=PCGrad,
822 | mgda=MGDA,
823 | cagrad=CAGrad,
824 | nashmtl=NashMTL,
825 | scaleinvls=ScaleInvariantLinearScalarization,
826 | rlw=RLW,
827 | imtl=IMTLG,
828 | dwa=DynamicWeightAverage,
829 | )
830 |
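# A minimal end-to-end sketch on a toy two-task model (illustrative only),
# mirroring how the trainers under experiments/ drive the WeightMethods facade.
def _demo_weight_methods_facade():
    model = torch.nn.Linear(4, 2)
    method = WeightMethods("ls", n_tasks=2, device=torch.device("cpu"))
    optimizer = torch.optim.Adam(
        [
            dict(params=model.parameters(), lr=1e-3),
            dict(params=method.parameters(), lr=1e-3),  # empty for methods without learnable weights
        ]
    )
    out = model(torch.randn(8, 4))
    losses = torch.stack((out[:, 0].pow(2).mean(), out[:, 1].pow(2).mean()))
    optimizer.zero_grad()
    loss, extra = method.backward(losses, shared_parameters=list(model.parameters()))
    optimizer.step()
    print(loss.item(), extra["weights"])
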
--------------------------------------------------------------------------------