├── experiments
│   ├── __init__.py
│   ├── nyuv2
│   │   ├── __init__.py
│   │   ├── .DS_Store
│   │   ├── README.md
│   │   ├── utils.py
│   │   ├── data.py
│   │   ├── trainer.py
│   │   └── models.py
│   ├── quantum_chemistry
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── models.py
│   │   ├── utils.py
│   │   └── trainer.py
│   ├── .DS_Store
│   └── utils.py
├── .DS_Store
├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── modules.xml
│   └── IGB4MTL.iml
├── requirements.txt
├── methods
│   ├── __init__.py
│   ├── weight_method.py
│   ├── min_norm_solvers.py
│   ├── SAC_Agent.py
│   ├── loss_weight_methods.py
│   └── gradient_weight_methods.py
├── LICENSE
└── README.md
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/nyuv2/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/quantum_chemistry/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YanqiDai/IGB4MTL/HEAD/.DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/experiments/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YanqiDai/IGB4MTL/HEAD/experiments/.DS_Store
--------------------------------------------------------------------------------
/experiments/nyuv2/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YanqiDai/IGB4MTL/HEAD/experiments/nyuv2/.DS_Store
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.2.1
2 | numpy>=1.18.2
3 | torch>=1.4.0
4 | torchvision>=0.8.0
5 | cvxpy
6 | tqdm>=4.45.0
7 | pandas
8 | scikit-learn
9 | seaborn
10 | plotly
11 | scipy==1.10.1
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/experiments/quantum_chemistry/README.md:
--------------------------------------------------------------------------------
1 | # QM9 Experiment
2 |
3 | Modification of the code in [Nash-MTL](https://github.com/AvivNavon/nash-mtl).
4 |
5 | ## Dataset
6 |
7 | The dataset will be downloaded automatically from torch_geometric and saved in `./dataset`.
8 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/experiments/nyuv2/README.md:
--------------------------------------------------------------------------------
1 | # NYUv2 Experiment
2 |
3 | Modification of the code in [Nash-MTL](https://github.com/AvivNavon/nash-mtl).
4 |
5 | ## Dataset
6 |
7 | The dataset is available at [this link](https://www.dropbox.com/sh/86nssgwm6hm3vkb/AACrnUQ4GxpdrBbLjb6n-mWNa?dl=0). Put the downloaded files in `./dataset` so that the folder structure is `./dataset/train` and `./dataset/val`.
--------------------------------------------------------------------------------
/.idea/IGB4MTL.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from methods.loss_weight_methods import (
2 | LOSS_METHODS,
3 | STL,
4 | LinearScalarization,
5 | Uncertainty,
6 | UncertaintyLog,
7 | ScaleInvariantLinearScalarization,
8 | RLW,
9 | RLWLog,
10 | DynamicWeightAverage,
11 | DynamicWeightAverageLog,
12 | ImprovableGapBalancing_v1,
13 | ImprovableGapBalancing_v2,
14 | )
15 | from methods.gradient_weight_methods import (
16 | GRADIENT_METHODS,
17 | PCGrad,
18 | MGDA,
19 | CAGrad,
20 | NashMTL,
21 | IMTLG,
22 | )
23 | from methods.SAC_Agent import SAC_Agent, RandomBuffer
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Yanqi Dai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/experiments/quantum_chemistry/models.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import Iterator
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.nn import GRU, Linear, ReLU, Sequential
7 | from torch_geometric.nn import DimeNet, NNConv, Set2Set, radius_graph
8 |
9 |
10 | class Net(torch.nn.Module):
11 | def __init__(self, n_tasks, num_features=11, dim=64):
12 | super().__init__()
13 | self.n_tasks = n_tasks
14 | self.dim = dim
15 | self.lin0 = torch.nn.Linear(num_features, dim)
16 |
17 | nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim))
18 | self.conv = NNConv(dim, dim, nn, aggr="mean")
19 | self.gru = GRU(dim, dim)
20 |
21 | self.set2set = Set2Set(dim, processing_steps=3)
22 | self.lin1 = torch.nn.Linear(2 * dim, dim)
23 |
24 | self._init_task_heads()
25 |
26 | def _init_task_heads(self):
27 | for i in range(self.n_tasks):
28 | setattr(self, f"head_{i}", torch.nn.Linear(self.dim, 1))
29 | self.task_specific = torch.nn.ModuleList(
30 | [getattr(self, f"head_{i}") for i in range(self.n_tasks)]
31 | )
32 |
33 | def forward(self, data, return_representation=False):
34 | out = F.relu(self.lin0(data.x))
35 | h = out.unsqueeze(0)
36 |
37 | for i in range(3):
38 | m = F.relu(self.conv(out, data.edge_index, data.edge_attr))
39 | out, h = self.gru(m.unsqueeze(0), h)
40 | out = out.squeeze(0)
41 |
42 | out = self.set2set(out, data.batch)
43 | features = F.relu(self.lin1(out))
44 | logits = torch.cat(
45 | [getattr(self, f"head_{i}")(features) for i in range(self.n_tasks)], dim=1
46 | )
47 | if return_representation:
48 | return logits, features
49 | return logits
50 |
51 | def shared_parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
52 | return chain(
53 | self.lin0.parameters(),
54 | self.conv.parameters(),
55 | self.gru.parameters(),
56 | self.set2set.parameters(),
57 | self.lin1.parameters(),
58 | )
59 |
60 | def task_specific_parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
61 | return self.task_specific.parameters()
62 |
63 | def last_shared_parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
64 | return self.lin1.parameters()
65 |
--------------------------------------------------------------------------------
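
A minimal sketch (not in the repository) of a forward pass through `Net` above on a single dummy molecule, assuming `torch_geometric` is installed. The 11 node features match `num_features=11`; the 5 edge features mirror QM9's 4 bond features plus the distance appended by `T.Distance` in `trainer.py`. All tensor values are random stand-ins.

```python
import torch
from torch_geometric.data import Batch, Data

from experiments.quantum_chemistry.models import Net

# One dummy molecule: 4 atoms with 11 node features, 4 directed edges with 5 edge features.
x = torch.randn(4, 11)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
edge_attr = torch.randn(edge_index.size(1), 5)
data = Batch.from_data_list([Data(x=x, edge_index=edge_index, edge_attr=edge_attr)])

model = Net(n_tasks=11)
logits, features = model(data, return_representation=True)
print(logits.shape)  # torch.Size([1, 11]): one prediction per task
```
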
/experiments/quantum_chemistry/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch_geometric.utils import remove_self_loops
4 |
5 |
6 | class MyTransform(object):
7 | def __init__(self, target: list = None):
8 | if target is None:
9 | target = torch.tensor([0, 1, 2, 3, 5, 6, 12, 13, 14, 15, 11]) # removing 4
10 | else:
11 | target = torch.tensor(target)
12 | self.target = target
13 |
14 | def __call__(self, data):
15 | # Specify target.
16 | data.y = data.y[:, self.target]
17 | return data
18 |
19 |
20 | class Complete(object):
21 | def __call__(self, data):
22 | device = data.edge_index.device
23 |
24 | row = torch.arange(data.num_nodes, dtype=torch.long, device=device)
25 | col = torch.arange(data.num_nodes, dtype=torch.long, device=device)
26 |
27 | row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)
28 | col = col.repeat(data.num_nodes)
29 | edge_index = torch.stack([row, col], dim=0)
30 |
31 | edge_attr = None
32 | if data.edge_attr is not None:
33 | idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]
34 | size = list(data.edge_attr.size())
35 | size[0] = data.num_nodes * data.num_nodes
36 | edge_attr = data.edge_attr.new_zeros(size)
37 | edge_attr[idx] = data.edge_attr
38 |
39 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
40 | data.edge_attr = edge_attr
41 | data.edge_index = edge_index
42 |
43 | return data
44 |
45 |
46 | qm9_target_dict = {
47 | 0: "mu",
48 | 1: "alpha",
49 | 2: "homo",
50 | 3: "lumo",
51 | 5: "r2",
52 | 6: "zpve",
53 | 7: "U0",
54 | 8: "U",
55 | 9: "H",
56 | 10: "G",
57 | 11: "Cv",
58 | }
59 |
60 | # for \Delta_m calculations
61 | # -------------------------
62 | # DimeNet uses the atomization energy for targets U0, U, H, and G.
63 | target_idx = [0, 1, 2, 3, 5, 6, 12, 13, 14, 15, 11]
64 |
65 | # Report meV instead of eV.
66 | multiply_indx = [2, 3, 5, 6, 7, 8, 9]
67 |
68 | n_tasks = len(target_idx)
69 |
70 | # stl results
71 | BASE = np.array(
72 | [
73 | 0.0671,
74 | 0.1814,
75 | 60.576,
76 | 53.915,
77 | 0.5027,
78 | 4.539,
79 | 58.838,
80 | 64.244,
81 | 63.852,
82 | 66.223,
83 | 0.07212,
84 | ]
85 | )
86 |
87 | SIGN = np.array([0] * n_tasks)
88 | KK = np.ones(n_tasks) * -1
89 |
90 |
91 | def delta_fn(a):
92 | return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0 # *100 for percentage
93 |
--------------------------------------------------------------------------------
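
A small illustration (not in the repository) of how `delta_fn` above behaves: since `SIGN` is all zeros, `KK ** SIGN` is 1 for every task, so Δm is simply the mean relative change of the per-task MAEs against the STL baselines in `BASE` (negative means better than STL on average). The input values here are hypothetical.

```python
import numpy as np

from experiments.quantum_chemistry.utils import BASE, delta_fn

# Hypothetical MTL result: 5% lower MAE than the STL baseline on every task.
mtl_mae = BASE * 0.95
print(delta_fn(mtl_mae))  # ~ -5.0 (percent), i.e. a 5% average improvement over STL
```
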
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
31 |
32 |
33 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/experiments/nyuv2/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class ConfMatrix(object):
6 | def __init__(self, num_classes):
7 | self.num_classes = num_classes
8 | self.mat = None
9 |
10 | def update(self, pred, target):
11 | n = self.num_classes
12 | if self.mat is None:
13 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
14 | with torch.no_grad():
15 | k = (target >= 0) & (target < n)
16 | inds = n * target[k].to(torch.int64) + pred[k]
17 | self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
18 |
19 | def get_metrics(self):
20 | h = self.mat.float()
21 | acc = torch.diag(h).sum() / h.sum()
22 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
23 | return torch.mean(iu).cpu().numpy(), acc.cpu().numpy()
24 |
25 |
26 | def depth_error(x_pred, x_output):
27 | device = x_pred.device
28 | binary_mask = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).to(device)
29 | x_pred_true = x_pred.masked_select(binary_mask)
30 | x_output_true = x_output.masked_select(binary_mask)
31 | abs_err = torch.abs(x_pred_true - x_output_true)
32 | rel_err = torch.abs(x_pred_true - x_output_true) / x_output_true
33 | return (
34 | torch.sum(abs_err) / torch.nonzero(binary_mask, as_tuple=False).size(0)
35 | ).item(), (
36 | torch.sum(rel_err) / torch.nonzero(binary_mask, as_tuple=False).size(0)
37 | ).item()
38 |
39 |
40 | def normal_error(x_pred, x_output):
41 | binary_mask = torch.sum(x_output, dim=1) != 0
42 | error = (
43 | torch.acos(
44 | torch.clamp(
45 | torch.sum(x_pred * x_output, 1).masked_select(binary_mask), -1, 1
46 | )
47 | )
48 | .detach()
49 | .cpu()
50 | .numpy()
51 | )
52 | error = np.degrees(error)
53 | return (
54 | np.mean(error),
55 | np.median(error),
56 | np.mean(error < 11.25),
57 | np.mean(error < 22.5),
58 | np.mean(error < 30),
59 | )
60 |
61 |
62 | # for calculating \Delta_m
63 | delta_stats = [
64 | "mean iou",
65 | "pix acc",
66 | "abs err",
67 | "rel err",
68 | "mean",
69 | "median",
70 | "<11.25",
71 | "<22.5",
72 | "<30",
73 | ]
74 | BASE = np.array(
75 | [0.4116, 0.657, 0.6074, 0.24, 24.49, 18.24, 0.3192, 0.5916, 0.7056]
76 | ) # base results of STL
77 | SIGN = np.array([1, 1, 0, 0, 0, 0, 1, 1, 1])
78 | KK = np.ones(9) * -1
79 |
80 |
81 | def delta_fn(a):
82 | return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0 # * 100 for percentage
83 |
84 |
85 | M_NUM = np.array([[0, 2], [2, 4], [4, 9]])
86 |
87 |
88 | def stl_eval_mean(a, main_task):
89 | a = KK ** SIGN * a
90 | eval_values = np.array([a[bound[0]: bound[1]].mean() for bound in M_NUM])
91 | return eval_values[main_task]
92 |
--------------------------------------------------------------------------------
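
A minimal usage sketch (not in the repository) of the `ConfMatrix` helper above, with random stand-in tensors; 13 classes is the usual NYUv2 semantic-segmentation setting assumed here, not something stated in this file.

```python
import torch

from experiments.nyuv2.utils import ConfMatrix

conf_mat = ConfMatrix(num_classes=13)

# Random per-pixel predictions and labels standing in for a validation batch.
pred = torch.randint(0, 13, (2, 288, 384))
target = torch.randint(0, 13, (2, 288, 384))
conf_mat.update(pred.flatten(), target.flatten())

mean_iou, pix_acc = conf_mat.get_metrics()
```
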
/README.md:
--------------------------------------------------------------------------------
1 | # IGB4MTL
2 |
3 | Official implementation of _"Improvable Gap Balancing for Multi-Task Learning"_, which has been accepted to UAI 2023.
4 |
5 | ## Setup environment
6 |
7 | ```bash
8 | conda create -n igb4mtl python=3.8.13
9 | conda activate igb4mtl
10 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=10.2 -c pytorch
11 | ```
12 |
13 | Install the repo:
14 |
15 | ```bash
16 | git clone https://github.com/YanqiDai/IGB4MTL.git
17 | cd IGB4MTL
18 | pip install -r requirements.txt
19 | ```
20 |
21 | ## Run experiment
22 |
23 | Follow the instructions in each experiment's README file for more information regarding, e.g., datasets.
24 |
25 | We support our IGB methods and other existing MTL methods with a unified API. To run experiments:
26 |
27 | ```bash
28 | cd experiments/<experiment name>/
29 | python trainer.py --loss_method=<loss method> --gradient_method=<gradient method>
30 | ```
31 |
32 | Here,
33 | - `<experiment name>` is one of `[quantum_chemistry, nyuv2]`.
34 | - `<loss method>` is one of `igbv1`, `igbv2`, and the following loss balancing MTL methods.
35 | - `<gradient method>` is one of the following gradient balancing MTL methods.
36 | - Both `<loss method>` and `<gradient method>` are optional:
37 |   - only using `<loss method>` runs a loss balancing method;
38 |   - only using `<gradient method>` runs a gradient balancing method;
39 |   - using neither runs the Equal Weighting (EW) method;
40 |   - using both runs a combined MTL method with both loss balancing and gradient balancing.
41 |
42 | ## MTL methods
43 |
44 | We support the following loss balancing and gradient balancing methods.
45 |
46 | | Loss Balancing Method (code name) | Paper (notes) |
47 | |:-------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------:|
48 | | Equal Weighting (`ls`) | - (linear scalarization) |
49 | | Random Loss Weighting (`rlw`) | [A Closer Look at Loss Weighting in Multi-Task Learning](https://arxiv.org/pdf/2111.10603.pdf) |
50 | | Dynamic Weight Average (`dwa`) | [End-to-End Multi-Task Learning with Attention](https://arxiv.org/abs/1803.10704) |
51 | | Uncertainty Weighting (`uw`) | [Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics](https://arxiv.org/pdf/1705.07115v3.pdf) |
52 | | Improvable Gap Balancing v1 (`igbv1`) | - (our first IGB method) |
53 | | Improvable Gap Balancing v2 (`igbv2`) | - (our second IGB method) |
54 |
55 |
56 | | Gradient Balancing Method (code name) | Paper (notes) |
57 | |:-------------------------------------:|:------------------------------------------------------------------------------------------------:|
58 | | MGDA (`mgda`) | [Multi-Task Learning as Multi-Objective Optimization](https://arxiv.org/abs/1810.04650) |
59 | | PCGrad (`pcgrad`) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/abs/2001.06782) |
60 | | CAGrad (`cagrad`) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048.pdf) |
61 | | IMTL-G (`imtl`) | [Towards Impartial Multi-task Learning](https://openreview.net/forum?id=IMPnRXEWpvr) |
62 | | Nash-MTL (`nashmtl`) | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017v1.pdf) |
63 |
--------------------------------------------------------------------------------
/experiments/nyuv2/data.py:
--------------------------------------------------------------------------------
1 | import fnmatch
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.utils.data.dataset import Dataset
9 |
10 | """
11 | Source: https://github.com/Cranial-XIX/CAGrad/blob/main/nyuv2/create_dataset.py
12 | """
13 |
14 |
15 | class RandomScaleCrop(object):
16 | """
17 | Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
18 | """
19 | def __init__(self, scale=None):
20 | if scale is None:
21 | scale = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5]
22 | self.scale = scale
23 |
24 | def __call__(self, img, label, depth, normal):
25 | height, width = img.shape[-2:]
26 | sc = self.scale[random.randint(0, len(self.scale) - 1)]
27 | h, w = int(height / sc), int(width / sc)
28 | i = random.randint(0, height - h)
29 | j = random.randint(0, width - w)
30 | img_ = F.interpolate(
31 | img[None, :, i: i + h, j: j + w],
32 | size=(height, width),
33 | mode="bilinear",
34 | align_corners=True,
35 | ).squeeze(0)
36 | label_ = (
37 | F.interpolate(
38 | label[None, None, i: i + h, j: j + w],
39 | size=(height, width),
40 | mode="nearest",
41 | )
42 | .squeeze(0)
43 | .squeeze(0)
44 | )
45 | depth_ = F.interpolate(
46 | depth[None, :, i: i + h, j: j + w], size=(height, width), mode="nearest"
47 | ).squeeze(0)
48 | normal_ = F.interpolate(
49 | normal[None, :, i: i + h, j: j + w],
50 | size=(height, width),
51 | mode="bilinear",
52 | align_corners=True,
53 | ).squeeze(0)
54 | return img_, label_, depth_ / sc, normal_
55 |
56 |
57 | class NYUv2(Dataset):
58 | def __init__(self, root, mode="train", augmentation=False):
59 | self.mode = mode
60 | self.root = os.path.expanduser(root)
61 | self.augmentation = augmentation
62 |
63 | # read the data file
64 | if mode == "train":
65 | self.data_path = root + "/train"
66 | elif mode == "val":
67 | self.data_path = root + "/val"
68 | else:
69 | self.data_path = root + "/test"
70 |
71 | # get data_files and calculate data length
72 | self.data_files = fnmatch.filter(os.listdir(self.data_path + "/image"), "*.npy")
73 | self.data_len = len(self.data_files)
74 |
75 | def __getitem__(self, index):
76 | # load data from the pre-processed npy files
77 | image = torch.from_numpy(
78 | np.moveaxis(
79 | np.load(self.data_path + "/image/{}".format(self.data_files[index])), -1, 0
80 | )
81 | )
82 | semantic = torch.from_numpy(
83 | np.load(self.data_path + "/label/{}".format(self.data_files[index]))
84 | )
85 | depth = torch.from_numpy(
86 | np.moveaxis(
87 | np.load(self.data_path + "/depth/{}".format(self.data_files[index])), -1, 0
88 | )
89 | )
90 | normal = torch.from_numpy(
91 | np.moveaxis(
92 | np.load(self.data_path + "/normal/{}".format(self.data_files[index])), -1, 0
93 | )
94 | )
95 |
96 | # apply data augmentation if required
97 | if self.augmentation:
98 | image, semantic, depth, normal = RandomScaleCrop()(
99 | image, semantic, depth, normal
100 | )
101 | if torch.rand(1) < 0.5:
102 | image = torch.flip(image, dims=[2])
103 | semantic = torch.flip(semantic, dims=[1])
104 | depth = torch.flip(depth, dims=[2])
105 | normal = torch.flip(normal, dims=[2])
106 | normal[0, :, :] = -normal[0, :, :]
107 |
108 | return image.float(), semantic.float(), depth.float(), normal.float()
109 |
110 | def __len__(self):
111 | return self.data_len
112 |
--------------------------------------------------------------------------------
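
A minimal sketch (not in the repository) of loading the dataset defined above, assuming the pre-processed `.npy` files sit under `./dataset/train` as described in the experiment README; the batch size is an arbitrary placeholder.

```python
from torch.utils.data import DataLoader

from experiments.nyuv2.data import NYUv2

train_set = NYUv2(root="./dataset", mode="train", augmentation=True)
train_loader = DataLoader(train_set, batch_size=2, shuffle=True)

# Each batch yields the image plus the three task targets.
image, semantic, depth, normal = next(iter(train_loader))
```
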
/experiments/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import random
4 | from collections import defaultdict
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from methods import LOSS_METHODS, GRADIENT_METHODS
11 |
12 |
13 | def str_to_list(string):
14 | return [float(s) for s in string.split(",")]
15 |
16 |
17 | def str_or_float(value):
18 | try:
19 | return float(value)
20 | except (TypeError, ValueError):
21 | return value
22 |
23 |
24 | def str2bool(v):
25 | if isinstance(v, bool):
26 | return v
27 | if v.lower() in ("yes", "true", "t", "y", "1"):
28 | return True
29 | elif v.lower() in ("no", "false", "f", "n", "0"):
30 | return False
31 | else:
32 | raise argparse.ArgumentTypeError("Boolean value expected.")
33 |
34 |
35 | common_parser = argparse.ArgumentParser(add_help=False)
36 | common_parser.add_argument("--data-path", type=Path, help="path to data")
37 | common_parser.add_argument("--n-epochs", type=int, default=500)
38 | common_parser.add_argument("--batch-size", type=int, default=2, help="batch size")
39 | common_parser.add_argument(
40 | "--loss_method",
41 | type=str,
42 | choices=list(LOSS_METHODS.keys()),
43 | default="ls",
44 | help="MTL loss weight method"
45 | )
46 | common_parser.add_argument(
47 | "--gradient_method",
48 | type=str,
49 | choices=list(GRADIENT_METHODS.keys()),
50 | default="ls",
51 | help="MTL gradient weight method"
52 | )
53 | common_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
54 | common_parser.add_argument(
55 | "--method-params-lr",
56 | type=float,
57 | default=0.025,
58 | help="lr for weight method params. If None, set to args.lr. For uncertainty weighting",
59 | )
60 | common_parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
61 | common_parser.add_argument("--seed", type=int, default=42, help="seed value")
62 | # NashMTL
63 | common_parser.add_argument(
64 | "--nashmtl-optim-niter", type=int, default=20, help="number of CCCP iterations"
65 | )
66 | common_parser.add_argument(
67 | "--update-weights-every",
68 | type=int,
69 | default=1,
70 | help="update task weights every x iterations.",
71 | )
72 | # stl
73 | common_parser.add_argument(
74 | "--main-task",
75 | type=int,
76 | default=0,
77 | help="main task for stl. Ignored if method != stl",
78 | )
79 | # cagrad
80 | common_parser.add_argument("--c", type=float, default=0.4, help="c for CAGrad alg.")
81 | # dwa
82 | common_parser.add_argument(
83 | "--dwa-temp",
84 | type=float,
85 | default=2.0,
86 | help="Temperature hyper-parameter for DWA. Default to 2 like in the original paper.",
87 | )
88 |
89 | # igbv1 and igbv2
90 | common_parser.add_argument(
91 | "--base_epoch",
92 | type=int,
93 | default=1,
94 | help="Set which epoch's average losses as base_losses for fw or fwlog",
95 | )
96 |
97 | # igbv2
98 | common_parser.add_argument(
99 | "--sac_lr",
100 | type=float,
101 | default=3e-4,
102 | help="learning rate of sac in igbv2",
103 | )
104 | common_parser.add_argument(
105 | "--buffer_size",
106 | type=float,
107 | default=1e4,
108 | help="max replay buffer size in igbv2",
109 | )
110 |
111 |
112 | def count_parameters(model):
113 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
114 |
115 |
116 | def set_logger():
117 | logging.basicConfig(
118 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
119 | level=logging.INFO,
120 | )
121 |
122 |
123 | def set_seed(seed):
124 | """for reproducibility
125 | :param seed:
126 | :return:
127 | """
128 | np.random.seed(seed)
129 | random.seed(seed)
130 |
131 | torch.manual_seed(seed)
132 | torch.cuda.manual_seed(seed)
133 | torch.cuda.manual_seed_all(seed)
134 |
135 | torch.backends.cudnn.enabled = False
136 | torch.backends.cudnn.benchmark = False
137 | torch.backends.cudnn.deterministic = True
138 |
139 |
140 | def get_device(no_cuda=False, gpus="0"):
141 | return torch.device(
142 | f"cuda:{gpus}" if torch.cuda.is_available() and not no_cuda else "cpu"
143 | )
144 |
145 |
146 | def extract_weight_method_parameters_from_args(args):
147 | weight_methods_parameters = defaultdict(dict)
148 | weight_methods_parameters.update(
149 | dict(
150 | nashmtl=dict(
151 | update_weights_every=args.update_weights_every,
152 | optim_niter=args.nashmtl_optim_niter,
153 | ),
154 | stl=dict(main_task=args.main_task),
155 | cagrad=dict(c=args.c),
156 | dwa=dict(temp=args.dwa_temp),
157 | igbv2=dict(sac_lr=args.sac_lr, buffer_size=int(args.buffer_size)),
158 | )
159 | )
160 | return weight_methods_parameters
161 |
--------------------------------------------------------------------------------
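
A minimal sketch (not in the repository) of how the shared `common_parser` above is consumed: an experiment script passes it as a parent parser, overrides defaults, then uses the seed/device helpers. This mirrors what the two `trainer.py` scripts do; `parse_args([])` is used here only so the sketch runs without command-line arguments.

```python
from argparse import ArgumentParser

from experiments.utils import common_parser, get_device, set_seed

parser = ArgumentParser("example", parents=[common_parser])
parser.set_defaults(n_epochs=300, batch_size=120)  # per-experiment overrides
args = parser.parse_args([])                       # defaults only, no CLI args

set_seed(args.seed)
device = get_device(gpus=args.gpu)
```
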
/methods/weight_method.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | from abc import abstractmethod
4 | from typing import Dict, List, Tuple, Union
5 |
6 | import cvxpy as cp
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from scipy.optimize import minimize
11 |
12 | from methods.min_norm_solvers import MinNormSolver, gradient_normalizers
13 |
14 |
15 | class WeightMethod:
16 | def __init__(self, n_tasks: int, device: torch.device):
17 | super().__init__()
18 | self.n_tasks = n_tasks
19 | self.device = device
20 |
21 | @abstractmethod
22 | def get_weighted_loss(
23 | self,
24 | losses: torch.Tensor,
25 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
26 | task_specific_parameters: Union[
27 | List[torch.nn.parameter.Parameter], torch.Tensor
28 | ],
29 | last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
30 | representation: Union[torch.nn.parameter.Parameter, torch.Tensor],
31 | **kwargs,
32 | ):
33 | pass
34 |
35 | @abstractmethod
36 | def get_weighted_losses(
37 | self,
38 | losses: torch.Tensor,
39 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
40 | task_specific_parameters: Union[
41 | List[torch.nn.parameter.Parameter], torch.Tensor
42 | ],
43 | last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
44 | representation: Union[torch.nn.parameter.Parameter, torch.Tensor],
45 | **kwargs,
46 | ):
47 | pass
48 |
49 | def backward(
50 | self,
51 | losses: torch.Tensor,
52 | shared_parameters: Union[
53 | List[torch.nn.parameter.Parameter], torch.Tensor
54 | ] = None,
55 | task_specific_parameters: Union[
56 | List[torch.nn.parameter.Parameter], torch.Tensor
57 | ] = None,
58 | last_shared_parameters: Union[
59 | List[torch.nn.parameter.Parameter], torch.Tensor
60 | ] = None,
61 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
62 | **kwargs,
63 | ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]:
64 | """
65 |
66 | Parameters
67 | ----------
68 | losses :
69 | shared_parameters :
70 | task_specific_parameters :
71 | last_shared_parameters : parameters of last shared layer/block
72 | representation : shared representation
73 | kwargs :
74 |
75 | Returns
76 | -------
77 | Loss, extra outputs
78 | """
79 | loss, extra_outputs = self.get_weighted_loss(
80 | losses=losses,
81 | shared_parameters=shared_parameters,
82 | task_specific_parameters=task_specific_parameters,
83 | last_shared_parameters=last_shared_parameters,
84 | representation=representation,
85 | **kwargs,
86 | )
87 | loss.backward()
88 | return loss, extra_outputs
89 |
90 | def __call__(
91 | self,
92 | losses: torch.Tensor,
93 | shared_parameters: Union[
94 | List[torch.nn.parameter.Parameter], torch.Tensor
95 | ] = None,
96 | task_specific_parameters: Union[
97 | List[torch.nn.parameter.Parameter], torch.Tensor
98 | ] = None,
99 | **kwargs,
100 | ):
101 | return self.backward(
102 | losses=losses,
103 | shared_parameters=shared_parameters,
104 | task_specific_parameters=task_specific_parameters,
105 | **kwargs,
106 | )
107 |
108 | def parameters(self) -> List[torch.Tensor]:
109 | """return learnable parameters"""
110 | return []
111 |
112 |
113 | class LinearScalarization(WeightMethod):
114 | """Linear scalarization baseline L = sum_j w_j * l_j where l_j is the loss for task j and w_h"""
115 |
116 | def __init__(
117 | self,
118 | n_tasks: int,
119 | device: torch.device,
120 | task_weights: Union[List[float], torch.Tensor] = None,
121 | ):
122 | super().__init__(n_tasks, device=device)
123 | if task_weights is None:
124 | task_weights = torch.ones((n_tasks,))
125 | if not isinstance(task_weights, torch.Tensor):
126 | task_weights = torch.tensor(task_weights)
127 | assert len(task_weights) == n_tasks
128 | self.task_weights = task_weights.to(device)
129 |
130 | def get_weighted_loss(self, losses, **kwargs):
131 | loss = torch.sum(losses * self.task_weights)
132 | return loss, dict(weights=self.task_weights)
133 |
134 | def get_weighted_losses(self, losses, **kwargs):
135 | losses = losses * self.task_weights
136 | return losses, dict(weights=self.task_weights)
137 |
--------------------------------------------------------------------------------
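
A minimal sketch (not in the repository) of the `WeightMethod` API above using the `LinearScalarization` baseline: calling `backward` backpropagates the weighted sum of the task losses. The two-parameter "model" below is a toy stand-in.

```python
import torch

from methods.weight_method import LinearScalarization

device = torch.device("cpu")
w = torch.tensor([1.0, 2.0], requires_grad=True)  # toy shared parameters
x = torch.randn(8, 2)
losses = torch.stack([(x[:, 0] * w[0]).pow(2).mean(),
                      (x[:, 1] * w[1]).pow(2).mean()])

method = LinearScalarization(n_tasks=2, device=device)
loss, extra = method.backward(losses=losses, shared_parameters=[w])
print(loss.item(), extra["weights"])  # summed loss, unit task weights
```
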
/methods/min_norm_solvers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | # This code is from
6 | # Multi-Task Learning as Multi-Objective Optimization
7 | # Ozan Sener, Vladlen Koltun
8 | # Neural Information Processing Systems (NeurIPS) 2018
9 | # https://github.com/intel-isl/MultiObjectiveOptimization
10 | class MinNormSolver:
11 | MAX_ITER = 250
12 | STOP_CRIT = 1e-5
13 |
14 | @staticmethod
15 | def _min_norm_element_from2(v1v1, v1v2, v2v2):
16 | """
17 | Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2
18 | d is the distance (objective) optimized
19 | v1v1 = <x1, x1>
20 | v1v2 = <x1, x2>
21 | v2v2 = <x2, x2>
22 | """
23 | if v1v2 >= v1v1:
24 | # Case: Fig 1, third column
25 | gamma = 0.999
26 | cost = v1v1
27 | return gamma, cost
28 | if v1v2 >= v2v2:
29 | # Case: Fig 1, first column
30 | gamma = 0.001
31 | cost = v2v2
32 | return gamma, cost
33 | # Case: Fig 1, second column
34 | gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
35 | cost = v2v2 + gamma * (v1v2 - v2v2)
36 | return gamma, cost
37 |
38 | @staticmethod
39 | def _min_norm_2d(vecs, dps):
40 | """
41 | Find the minimum norm solution as combination of two points
42 | This is correct only in 2D
43 | ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j
44 | """
45 | dmin = 1e8
46 | for i in range(len(vecs)):
47 | for j in range(i + 1, len(vecs)):
48 | if (i, j) not in dps:
49 | dps[(i, j)] = 0.0
50 | for k in range(len(vecs[i])):
51 | dps[(i, j)] += torch.dot(
52 | vecs[i][k], vecs[j][k]
53 | ).item() # torch.dot(vecs[i][k], vecs[j][k]).dataset[0]
54 | dps[(j, i)] = dps[(i, j)]
55 | if (i, i) not in dps:
56 | dps[(i, i)] = 0.0
57 | for k in range(len(vecs[i])):
58 | dps[(i, i)] += torch.dot(
59 | vecs[i][k], vecs[i][k]
60 | ).item() # torch.dot(vecs[i][k], vecs[i][k]).dataset[0]
61 | if (j, j) not in dps:
62 | dps[(j, j)] = 0.0
63 | for k in range(len(vecs[i])):
64 | dps[(j, j)] += torch.dot(
65 | vecs[j][k], vecs[j][k]
66 | ).item() # torch.dot(vecs[j][k], vecs[j][k]).dataset[0]
67 | c, d = MinNormSolver._min_norm_element_from2(
68 | dps[(i, i)], dps[(i, j)], dps[(j, j)]
69 | )
70 | if d < dmin:
71 | dmin = d
72 | sol = [(i, j), c, d]
73 | return sol, dps
74 |
75 | @staticmethod
76 | def _projection2simplex(y):
77 | """
78 | Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
79 | """
80 | m = len(y)
81 | sorted_y = np.flip(np.sort(y), axis=0)
82 | tmpsum = 0.0
83 | tmax_f = (np.sum(y) - 1.0) / m
84 | for i in range(m - 1):
85 | tmpsum += sorted_y[i]
86 | tmax = (tmpsum - 1) / (i + 1.0)
87 | if tmax > sorted_y[i + 1]:
88 | tmax_f = tmax
89 | break
90 | return np.maximum(y - tmax_f, np.zeros(y.shape))
91 |
92 | @staticmethod
93 | def _next_point(cur_val, grad, n):
94 | proj_grad = grad - (np.sum(grad) / n)
95 | tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
96 | tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])
97 |
98 | skippers = np.sum(tm1 < 1e-7) + np.sum(tm2 < 1e-7)
99 | t = 1
100 | if len(tm1[tm1 > 1e-7]) > 0:
101 | t = np.min(tm1[tm1 > 1e-7])
102 | if len(tm2[tm2 > 1e-7]) > 0:
103 | t = min(t, np.min(tm2[tm2 > 1e-7]))
104 |
105 | next_point = proj_grad * t + cur_val
106 | next_point = MinNormSolver._projection2simplex(next_point)
107 | return next_point
108 |
109 | @staticmethod
110 | def find_min_norm_element(vecs):
111 | """
112 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
113 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
114 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
115 | Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence
116 | """
117 | # Solution lying at the combination of two points
118 | dps = {}
119 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
120 |
121 | n = len(vecs)
122 | sol_vec = np.zeros(n)
123 | sol_vec[init_sol[0][0]] = init_sol[1]
124 | sol_vec[init_sol[0][1]] = 1 - init_sol[1]
125 |
126 | if n < 3:
127 | # This is optimal for n=2, so return the solution
128 | return sol_vec, init_sol[2]
129 |
130 | iter_count = 0
131 |
132 | grad_mat = np.zeros((n, n))
133 | for i in range(n):
134 | for j in range(n):
135 | grad_mat[i, j] = dps[(i, j)]
136 |
137 | while iter_count < MinNormSolver.MAX_ITER:
138 | grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
139 | new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
140 | # Re-compute the inner products for line search
141 | v1v1 = 0.0
142 | v1v2 = 0.0
143 | v2v2 = 0.0
144 | for i in range(n):
145 | for j in range(n):
146 | v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
147 | v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
148 | v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
149 | nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
150 | new_sol_vec = nc * sol_vec + (1 - nc) * new_point
151 | change = new_sol_vec - sol_vec
152 | if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
153 | return sol_vec, nd
154 | sol_vec = new_sol_vec
155 |
156 | @staticmethod
157 | def find_min_norm_element_FW(vecs):
158 | """
159 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
160 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
161 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
162 | Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence
163 | """
164 | # Solution lying at the combination of two points
165 | dps = {}
166 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
167 |
168 | n = len(vecs)
169 | sol_vec = np.zeros(n)
170 | sol_vec[init_sol[0][0]] = init_sol[1]
171 | sol_vec[init_sol[0][1]] = 1 - init_sol[1]
172 |
173 | if n < 3:
174 | # This is optimal for n=2, so return the solution
175 | return sol_vec, init_sol[2]
176 |
177 | iter_count = 0
178 |
179 | grad_mat = np.zeros((n, n))
180 | for i in range(n):
181 | for j in range(n):
182 | grad_mat[i, j] = dps[(i, j)]
183 |
184 | while iter_count < MinNormSolver.MAX_ITER:
185 | t_iter = np.argmin(np.dot(grad_mat, sol_vec))
186 |
187 | v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
188 | v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
189 | v2v2 = grad_mat[t_iter, t_iter]
190 |
191 | nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
192 | new_sol_vec = nc * sol_vec
193 | new_sol_vec[t_iter] += 1 - nc
194 |
195 | change = new_sol_vec - sol_vec
196 | if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
197 | return sol_vec, nd
198 | sol_vec = new_sol_vec
199 |
200 |
201 | def gradient_normalizers(grads, losses, normalization_type):
202 | gn = {}
203 | if normalization_type == "norm":
204 | for t in grads:
205 | gn[t] = np.sqrt(np.sum([gr.pow(2).sum().item() for gr in grads[t]]))
206 | elif normalization_type == "loss":
207 | for t in grads:
208 | gn[t] = losses[t]
209 | elif normalization_type == "loss+":
210 | for t in grads:
211 | gn[t] = losses[t] * np.sqrt(
212 | np.sum([gr.pow(2).sum().item() for gr in grads[t]])
213 | )
214 | elif normalization_type == "none":
215 | for t in grads:
216 | gn[t] = 1.0
217 | else:
218 | print("ERROR: Invalid Normalization Type")
219 | return gn
220 |
--------------------------------------------------------------------------------
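
A tiny example (not in the repository) of `MinNormSolver.find_min_norm_element` above. Each entry of `vecs` is a list of tensors, mirroring per-parameter gradient chunks; for two orthogonal unit gradients the minimum-norm convex combination is the midpoint.

```python
import torch

from methods.min_norm_solvers import MinNormSolver

g1 = [torch.tensor([1.0, 0.0])]  # gradients of task 1 (single parameter chunk)
g2 = [torch.tensor([0.0, 1.0])]  # gradients of task 2

sol, min_norm_sq = MinNormSolver.find_min_norm_element([g1, g2])
print(sol)           # ~[0.5, 0.5]
print(min_norm_sq)   # ~0.5, squared norm of 0.5 * g1 + 0.5 * g2
```
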
/experiments/quantum_chemistry/trainer.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import os
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import torch_geometric.transforms as T
9 | from torch_geometric.datasets import QM9
10 | from torch_geometric.loader import DataLoader
11 | from tqdm import trange
12 |
13 | import sys
14 | sys.path.append("../..")
15 | from experiments.quantum_chemistry.models import Net
16 | from experiments.quantum_chemistry.utils import (
17 | Complete,
18 | MyTransform,
19 | delta_fn,
20 | multiply_indx,
21 | )
22 | from experiments.quantum_chemistry.utils import target_idx as targets
23 | from experiments.utils import (
24 | common_parser,
25 | extract_weight_method_parameters_from_args,
26 | get_device,
27 | set_logger,
28 | set_seed,
29 | str2bool,
30 | )
31 |
32 | from methods.loss_weight_methods import LossWeightMethods
33 | from methods.gradient_weight_methods import GradientWeightMethods
34 |
35 | set_logger()
36 |
37 |
38 | @torch.no_grad()
39 | def evaluate(model, loader, std, scale_target):
40 | model.eval()
41 | data_size = 0.0
42 | task_losses = 0.0
43 | for i, data in enumerate(loader):
44 | data = data.to(device)
45 | out = model(data)
46 | if scale_target:
47 | task_losses += F.l1_loss(
48 | out * std.to(device), data.y * std.to(device), reduction="none"
49 | ).sum(
50 | 0
51 | ) # MAE
52 | else:
53 | task_losses += F.l1_loss(out, data.y, reduction="none").sum(0) # MAE
54 | data_size += len(data.y)
55 |
56 | model.train()
57 |
58 | avg_task_losses = task_losses / data_size
59 |
60 | # Report meV instead of eV.
61 | avg_task_losses = avg_task_losses.detach().cpu().numpy()
62 | avg_task_losses[multiply_indx] *= 1000
63 |
64 | delta_m = delta_fn(avg_task_losses)
65 | return dict(
66 | avg_loss=avg_task_losses.mean(),
67 | avg_task_losses=avg_task_losses,
68 | delta_m=delta_m,
69 | )
70 |
71 |
72 | def main(
73 | data_path: str,
74 | batch_size: int,
75 | device: torch.device,
76 | lr: float,
77 | n_epochs: int,
78 | targets: list = None,
79 | scale_target: bool = True,
80 | main_task: int = None,
81 | ):
82 | timestr = time.strftime("%Y%m%d-%H%M%S")
83 | os.makedirs("./logs", exist_ok=True)
84 | log_file = f"./logs/{timestr}_{args.loss_method}_{args.gradient_method}_seed{args.seed}_log.txt"
85 |
86 | dim = 64
87 | model = Net(n_tasks=len(targets), num_features=11, dim=dim).to(device)
88 |
89 | transform = T.Compose([MyTransform(targets), Complete(), T.Distance(norm=False)])
90 | dataset = QM9(data_path, transform=transform).shuffle()
91 |
92 | # Split datasets.
93 | test_dataset = dataset[:10000]
94 | val_dataset = dataset[10000:20000]
95 | train_dataset = dataset[20000:]
96 |
97 | std = None
98 | if scale_target:
99 | mean = train_dataset.data.y[:, targets].mean(dim=0, keepdim=True)
100 | std = train_dataset.data.y[:, targets].std(dim=0, keepdim=True)
101 |
102 | dataset.data.y[:, targets] = (dataset.data.y[:, targets] - mean) / std
103 |
104 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
105 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
106 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
107 |
108 | loss_weight_methods_parameters = extract_weight_method_parameters_from_args(args)
109 | loss_weight_method = LossWeightMethods(
110 | args.loss_method, n_tasks=len(targets), device=device, **loss_weight_methods_parameters[args.loss_method]
111 | )
112 | # gradient_weight method
113 | gradient_weight_methods_parameters = extract_weight_method_parameters_from_args(args)
114 | gradient_weight_method = GradientWeightMethods(
115 | args.gradient_method, n_tasks=len(targets), device=device, **gradient_weight_methods_parameters[args.gradient_method]
116 | )
117 |
118 | optimizer = torch.optim.Adam(
119 | [
120 | dict(params=model.parameters(), lr=lr),
121 | dict(params=loss_weight_method.parameters(), lr=args.method_params_lr),
122 | dict(params=gradient_weight_method.parameters(), lr=args.method_params_lr),
123 |
124 | ],
125 | )
126 |
127 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
128 | optimizer, mode="min", factor=0.7, patience=5, min_lr=0.00001
129 | )
130 |
131 | epoch_iterator = trange(n_epochs)
132 | train_batch = len(train_loader)
133 |
134 | best_val = np.inf
135 | best_val_delta = np.inf
136 |
137 | train_time_sum = 0.0
138 |
139 | # reward scale for IGBv2
140 | if args.loss_method == 'igbv2':
141 | loss_weight_method.method.train_batch = train_batch
142 |
143 | for epoch in epoch_iterator:
144 | lr = optimizer.param_groups[0]["lr"]
145 | avg_train_losses = torch.zeros(len(targets)).to(device)
146 | avg_loss_weights = torch.zeros(len(targets)).to(device)
147 |
148 | start_train_time = time.time()
149 |
150 | # reward scale for IGBv2
151 | if args.loss_method == 'igbv2':
152 | loss_weight_method.method.reward_scale = lr / optimizer.param_groups[0]['lr']
153 |
154 | for j, data in enumerate(train_loader):
155 | model.train()
156 |
157 | data = data.to(device)
158 | optimizer.zero_grad()
159 |
160 | out, features = model(data, return_representation=True)
161 |
162 | losses = F.mse_loss(out, data.y, reduction="none").mean(0)
163 | # print(losses)
164 | avg_train_losses += losses.detach() / train_batch
165 |
166 | weighted_losses, loss_weights = loss_weight_method.get_weighted_losses(
167 | losses=losses,
168 | shared_parameters=list(model.shared_parameters()),
169 | task_specific_parameters=list(model.task_specific_parameters()),
170 | last_shared_parameters=list(model.last_shared_parameters()),
171 | representation=features,
172 | )
173 | avg_loss_weights += loss_weights['weights'] / train_batch
174 |
175 | loss, gradient_weights = gradient_weight_method.backward(
176 | losses=weighted_losses,
177 | shared_parameters=list(model.shared_parameters()),
178 | task_specific_parameters=list(model.task_specific_parameters()),
179 | last_shared_parameters=list(model.last_shared_parameters()),
180 | representation=features,
181 | )
182 |
183 | optimizer.step()
184 |
185 | epoch_iterator.set_description(
186 | f"[{epoch} {j + 1}/{train_batch}]"
187 | )
188 |
189 | # base_losses for IGBv1 and IGBv2
190 | if 'igb' in args.loss_method and epoch == args.base_epoch:
191 | loss_weight_method.method.base_losses = avg_train_losses
192 |
193 | end_train_time = time.time()
194 | train_time_sum += end_train_time - start_train_time
195 |
196 | val_loss_dict = evaluate(model, val_loader, std=std, scale_target=scale_target)
197 | val_loss = val_loss_dict["avg_loss"]
198 | val_delta = val_loss_dict["delta_m"]
199 |
200 | results = f"Epoch: {epoch:04d}\n" \
201 | f"AVERAGE LOSS WEIGHTS: " \
202 | f"{avg_loss_weights[0]:.4f} {avg_loss_weights[1]:.4f} {avg_loss_weights[2]:.4f} " \
203 | f"{avg_loss_weights[3]:.4f} {avg_loss_weights[4]:.4f} {avg_loss_weights[5]:.4f} " \
204 | f"{avg_loss_weights[6]:.4f} {avg_loss_weights[7]:.4f} {avg_loss_weights[8]:.4f} " \
205 | f"{avg_loss_weights[9]:.4f} {avg_loss_weights[10]:.4f}\n" \
206 | f"TRAIN: {losses.mean().item():.3f}\n" \
207 | f"VAL: {val_loss:.3f} {val_delta:.3f}\n"
208 |
209 | if args.loss_method == "stl":
210 | best_val_criteria = val_loss_dict["avg_task_losses"][main_task] <= best_val
211 | else:
212 | best_val_criteria = val_delta <= best_val_delta
213 |
214 | if best_val_criteria:
215 | best_val = val_loss
216 | best_val_delta = val_delta
217 |
218 | test_loss_dict = evaluate(model, test_loader, std=std, scale_target=scale_target)
219 | test_loss = test_loss_dict["avg_loss"]
220 | test_task_losses = test_loss_dict["avg_task_losses"]
221 | test_delta = test_loss_dict["delta_m"]
222 | test_result = f"TEST: {test_loss:.3f} {test_delta:.3f}\n"
223 | test_result += f"TEST LOSSES: "
224 | for i in range(len(targets)):
225 | test_result += f"{test_task_losses[i]:.3f} "
226 | test_result = test_result[:-1] + "\n"
227 | print(test_result, end='')
228 | results += test_result
229 |
230 | with open(log_file, mode="a") as log_f:
231 | log_f.write(results)
232 |
233 | scheduler.step(
234 | val_loss_dict["avg_task_losses"][main_task]
235 | if args.loss_method == "stl"
236 | else val_delta
237 | )
238 |
239 | train_time_log = f"Training time: {int(train_time_sum)}s\n"
240 | print(train_time_log, end='')
241 | with open(log_file, mode="a") as log_f:
242 | log_f.write(train_time_log)
243 |
244 |
245 | if __name__ == "__main__":
246 | parser = ArgumentParser("QM9", parents=[common_parser])
247 | parser.set_defaults(
248 | data_path="./dataset",
249 | lr=1e-3,
250 | n_epochs=300,
251 | batch_size=120,
252 | )
253 | parser.add_argument("--scale-y", default=True, type=str2bool)
254 | args = parser.parse_args()
255 |
256 | # set seed
257 | set_seed(args.seed)
258 |
259 | device = get_device(gpus=args.gpu)
260 | main(
261 | data_path=args.data_path,
262 | batch_size=args.batch_size,
263 | device=device,
264 | lr=args.lr,
265 | n_epochs=args.n_epochs,
266 | targets=targets,
267 | scale_target=args.scale_y,
268 | main_task=args.main_task,
269 | )
270 |
--------------------------------------------------------------------------------
/methods/SAC_Agent.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.distributions import Normal
8 |
9 |
10 | """
11 | Soft Actor Critic
12 | Modification of: https://github.com/XinJingHao/SAC-Continuous-Pytorch
13 | """
14 |
15 |
16 | class RandomBuffer(object):
17 | def __init__(self, state_dim, action_dim, max_size=int(1e4), device='cuda'):
18 | self.max_size = max_size
19 | self.ptr = 0
20 | self.size = 0
21 |
22 | self.state = torch.zeros((max_size, state_dim))
23 | self.action = torch.zeros((max_size, action_dim))
24 | self.reward = torch.zeros((max_size, 1))
25 | self.next_state = torch.zeros((max_size, state_dim))
26 |
27 | self.device = device
28 |
29 | def add(self, state, action, reward, next_state):
30 | self.state[self.ptr] = state
31 | self.action[self.ptr] = action
32 | self.reward[self.ptr] = reward
33 | self.next_state[self.ptr] = next_state
34 |
35 | self.ptr = (self.ptr + 1) % self.max_size
36 | self.size = min(self.size + 1, self.max_size)
37 |
38 | def clean(self):
39 | self.size = 0
40 |
41 | def sample(self, batch_size):
42 | # ind = np.random.randint(0, self.size, size=batch_size)
43 | ind = torch.randint(0, self.size, size=(1, batch_size)).squeeze()
44 | with torch.no_grad():
45 | return (
46 | self.state[ind].to(self.device),
47 | self.action[ind].to(self.device),
48 | self.reward[ind].to(self.device),
49 | self.next_state[ind].to(self.device),
50 | )
51 |
52 | # error, need to rewrite for tensor
53 | def save(self):
54 | """save the replay buffer if you want"""
55 | scaller = np.array([self.max_size, self.ptr, self.size], dtype=np.uint32)
56 | np.save("buffer/scaller.npy", scaller)
57 | np.save("buffer/state.npy", self.state)
58 | np.save("buffer/action.npy", self.action)
59 | np.save("buffer/reward.npy", self.reward)
60 | np.save("buffer/next_state.npy", self.next_state)
61 |
62 | # error, need to rewrite for tensor
63 | def load(self):
64 | scaller = np.load("buffer/scaller.npy")
65 |
66 | self.max_size = scaller[0]
67 | self.ptr = scaller[1]
68 | self.size = scaller[2]
69 |
70 | self.state = np.load("buffer/state.npy")
71 | self.action = np.load("buffer/action.npy")
72 | self.reward = np.load("buffer/reward.npy")
73 | self.next_state = np.load("buffer/next_state.npy")
74 |
75 |
76 | # Build net with for loop: multi-layer MLP with activation function
77 | def build_net(layer_shape, activation, output_activation):
78 | layers = []
79 | for j in range(len(layer_shape) - 1):
80 | act = activation if j < len(layer_shape) - 2 else output_activation
81 | layers += [nn.Linear(layer_shape[j], layer_shape[j + 1]), act()]
82 | return nn.Sequential(*layers)
83 |
84 |
85 | class Actor(nn.Module):
86 | def __init__(self, state_dim, action_dim, hid_shape, h_acti=nn.ReLU, o_acti=nn.ReLU):
87 | super(Actor, self).__init__()
88 |
89 | layers = [state_dim] + list(hid_shape)
90 | self.a_net = build_net(layers, h_acti, o_acti)
91 | self.mu_layer = nn.Linear(layers[-1], action_dim)
92 | self.log_std_layer = nn.Linear(layers[-1], action_dim)
93 |
94 | self.LOG_STD_MAX = 2
95 | self.LOG_STD_MIN = -20
96 |
97 | self.action_dim = action_dim
98 |
99 | def forward(self, state, deterministic=False, with_logprob=True):
100 | """Network with Enforcing Action Bounds"""
101 | net_out = self.a_net(state)
102 | mu = self.mu_layer(net_out)
103 | log_std = self.log_std_layer(net_out)
104 | log_std = torch.clamp(log_std, self.LOG_STD_MIN, self.LOG_STD_MAX)
105 | std = torch.exp(log_std)
106 | dist = Normal(mu, std)
107 |
108 | if deterministic:
109 | u = mu
110 | else:
111 | u = dist.rsample() # reparameterization trick of Gaussian
112 | # a = torch.tanh(u) # dai norm_action
113 | a = self.action_dim * F.softmax(u, dim=-1)
114 |
115 | if with_logprob:
116 | # get probability density of logp_pi_a from probability density of u, which is given by the original paper.
117 | # logp_pi_a = (dist.log_prob(u) - torch.log(1 - a.pow(2) + 1e-6)).sum(dim=1, keepdim=True)
118 |
119 | # Derive from the above equation. No a, thus no tanh(h), thus less gradient vanish and more stable.
120 | logp_pi_a = dist.log_prob(u).sum(axis=1, keepdim=True) - (2 * (np.log(2) - u - F.softplus(-2 * u))).sum(
121 | axis=1, keepdim=True)
122 | else:
123 | logp_pi_a = None
124 |
125 | return a, logp_pi_a
126 |
127 |
128 | class Q_Critic(nn.Module):
129 | def __init__(self, state_dim, action_dim, hid_shape):
130 | super(Q_Critic, self).__init__()
131 | layers = [state_dim + action_dim] + list(hid_shape) + [1]
132 |
133 | self.Q_1 = build_net(layers, nn.ReLU, nn.Identity)
134 | self.Q_2 = build_net(layers, nn.ReLU, nn.Identity)
135 |
136 | def forward(self, state, action):
137 | sa = torch.cat([state, action], 1)
138 | # print(sa.size())
139 | q1 = self.Q_1(sa)
140 | q2 = self.Q_2(sa)
141 | return q1, q2
142 |
143 |
144 | class SAC_Agent(nn.Module):
145 | def __init__(
146 | self,
147 | state_dim,
148 | action_dim,
149 | gamma=0.99,
150 | hid_shape=(128, 128),
151 | a_lr=3e-4,
152 | c_lr=3e-4,
153 | batch_size=256,
154 | alpha=0.2,
155 | adaptive_alpha=True,
156 | device='cuda'
157 | ):
158 |
159 | super(SAC_Agent, self).__init__()
160 | self.actor = Actor(state_dim, action_dim, hid_shape).to(device)
161 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=a_lr)
162 |
163 | self.q_critic = Q_Critic(state_dim, action_dim, hid_shape).to(device)
164 | self.q_critic_optimizer = torch.optim.Adam(self.q_critic.parameters(), lr=c_lr)
165 | self.q_critic_target = copy.deepcopy(self.q_critic)
166 | # Freeze target networks with respect to optimizers (only update via polyak averaging)
167 | for p in self.q_critic_target.parameters():
168 | p.requires_grad = False
169 |
170 | self.action_dim = action_dim
171 | self.gamma = gamma
172 | self.tau = 0.005
173 | self.batch_size = batch_size
174 |
175 | self.alpha = alpha
176 | self.adaptive_alpha = adaptive_alpha
177 | if adaptive_alpha:
178 | # Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper
179 | self.target_entropy = torch.tensor(-action_dim, dtype=float, requires_grad=True, device=device)
180 | # We learn log_alpha instead of alpha to ensure exp(log_alpha)=alpha>0
181 | self.log_alpha = torch.tensor(np.log(alpha), dtype=float, requires_grad=True, device=device)
182 | self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=c_lr)
183 |
184 | self.device = device
185 |
186 | def select_action(self, state, deterministic, with_logprob=False):
187 | # only used when interact with the env
188 | with torch.no_grad():
189 | # state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
190 | a, _ = self.actor(state, deterministic, with_logprob)
191 | # return a.cpu().numpy().flatten()
192 | return a
193 |
194 | # def train(self, replay_buffer):
195 | def train(self, replay_buffer, k):
196 | s, a, r, s_prime = replay_buffer.sample(int(self.batch_size * (k / 2)))
197 | # s, a, r, s_prime = replay_buffer.sample(self.batch_size)
198 |
199 | # ----------------------------- ↓↓↓↓↓ Update Q Net ↓↓↓↓↓ ------------------------------ #
200 | with torch.no_grad():
201 | a_prime, log_pi_a_prime = self.actor(s_prime)
202 | target_Q1, target_Q2 = self.q_critic_target(s_prime, a_prime)
203 | target_Q = torch.min(target_Q1, target_Q2)
204 | target_Q = r + self.gamma * (target_Q - self.alpha * log_pi_a_prime)
205 |
206 | # Get current Q estimates
207 | current_Q1, current_Q2 = self.q_critic(s, a)
208 |
209 | q_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
210 | self.q_critic_optimizer.zero_grad()
211 | q_loss.backward()
212 | self.q_critic_optimizer.step()
213 |
214 | # ----------------------------- ↓↓↓↓↓ Update Actor Net ↓↓↓↓↓ ------------------------------#
215 | # Freeze Q-networks so you don't waste computational effort
216 | # computing gradients for them during the policy learning step.
217 | for params in self.q_critic.parameters():
218 | params.requires_grad = False
219 |
220 | a, log_pi_a = self.actor(s)
221 | current_Q1, current_Q2 = self.q_critic(s, a)
222 | Q = torch.min(current_Q1, current_Q2)
223 |
224 | a_loss = (self.alpha * log_pi_a - Q).mean()
225 | self.actor_optimizer.zero_grad()
226 | a_loss.backward()
227 | self.actor_optimizer.step()
228 |
229 | for params in self.q_critic.parameters():
230 | params.requires_grad = True
231 | # ----------------------------- ↓↓↓↓↓ Update alpha ↓↓↓↓↓ ------------------------------#
232 | if self.adaptive_alpha:
233 | # we optimize log_alpha instead of alpha, which forces alpha = exp(log_alpha) > 0
234 | # if we optimized alpha directly, it could become negative, which would minimize entropy instead of maximizing it.
235 | alpha_loss = -(self.log_alpha * (log_pi_a + self.target_entropy).detach()).mean()
236 | self.alpha_optim.zero_grad()
237 | alpha_loss.backward()
238 | self.alpha_optim.step()
239 | self.alpha = self.log_alpha.exp()
240 |
241 | # ----------------------------- ↓↓↓↓↓ Update Target Net ↓↓↓↓↓ ------------------------------#
242 | for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()):
243 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
244 |
245 | def save(self, episode):
246 | torch.save(self.actor.state_dict(), "./rl_model/sac_actor{}.pth".format(episode))
247 | torch.save(self.q_critic.state_dict(), "./rl_model/sac_q_critic{}.pth".format(episode))
248 |
249 | def load(self, episode):
250 | self.actor.load_state_dict(torch.load("./rl_model/sac_actor{}.pth".format(episode)))
251 | self.q_critic.load_state_dict(torch.load("./rl_model/sac_q_critic{}.pth".format(episode)))
--------------------------------------------------------------------------------
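
A minimal interaction sketch (not in the repository) for the `SAC_Agent` / `RandomBuffer` pair above, as used by the IGBv2 loss-weighting method (see `loss_weight_methods.py`, not shown here): store transitions, then run a SAC update once enough samples exist. Dimensions, the warm-up length, and all transition values are illustrative stand-ins.

```python
import torch

from methods.SAC_Agent import SAC_Agent, RandomBuffer

n_tasks = 3
agent = SAC_Agent(state_dim=n_tasks, action_dim=n_tasks, batch_size=32, device="cpu")
buffer = RandomBuffer(state_dim=n_tasks, action_dim=n_tasks, device="cpu")

for _ in range(64):  # fill the buffer with dummy transitions
    state = torch.rand(n_tasks)
    action = agent.select_action(state.unsqueeze(0), deterministic=False).squeeze(0)
    reward = torch.rand(1)
    next_state = torch.rand(n_tasks)
    buffer.add(state, action, reward, next_state)

agent.train(buffer, k=2)  # one SAC update on a sampled batch
```
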
/experiments/nyuv2/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from argparse import ArgumentParser
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.utils.data import DataLoader
10 | from tqdm import trange
11 |
12 | import sys
13 | sys.path.append("../..")
14 | from experiments.nyuv2.data import NYUv2
15 | from experiments.nyuv2.models import SegNet, SegNetMtan
16 | from experiments.nyuv2.utils import ConfMatrix, delta_fn, depth_error, normal_error, stl_eval_mean
17 | from experiments.utils import (
18 | common_parser,
19 | extract_weight_method_parameters_from_args,
20 | get_device,
21 | set_logger,
22 | set_seed,
23 | str2bool,
24 | )
25 |
26 | from methods.loss_weight_methods import LossWeightMethods
27 | from methods.gradient_weight_methods import GradientWeightMethods
28 |
29 | set_logger()
30 |
31 |
32 | def calc_loss(x_pred, x_output, task_type):
33 | device = x_pred.device
34 |
35 |     # binary mask to mask out undefined pixels
36 | binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(device)
37 |
38 | if task_type == "semantic":
39 |         # semantic loss: pixel-wise cross entropy
40 | loss = F.nll_loss(x_pred, x_output, ignore_index=-1)
41 |
42 | if task_type == "depth":
43 | # depth loss: l1 norm
44 | loss = torch.sum(torch.abs(x_pred - x_output) * binary_mask) / torch.nonzero(
45 | binary_mask, as_tuple=False
46 | ).size(0)
47 |
48 | if task_type == "normal":
49 | # normal loss: dot product
50 | loss = 1 - torch.sum((x_pred * x_output) * binary_mask) / torch.nonzero(
51 | binary_mask, as_tuple=False
52 | ).size(0)
53 |
54 | return loss
55 |
56 |
57 | def main(path, lr, bs, device):
58 | timestr = time.strftime("%Y%m%d-%H%M%S")
59 | os.makedirs("./logs", exist_ok=True)
60 | log_file = f"./logs/{timestr}_{args.loss_method}_{args.gradient_method}_seed{args.seed}_log.txt"
61 |
62 | # Nets
63 | model = dict(segnet=SegNet(), mtan=SegNetMtan())[args.model]
64 | model = model.to(device)
65 |
66 | # dataset and dataloaders
67 | log_str = (
68 | "Applying data augmentation on NYUv2."
69 | if args.apply_augmentation
70 | else "Standard training strategy without data augmentation."
71 | )
72 | logging.info(log_str)
73 |
74 | nyuv2_train_set = NYUv2(root=path.as_posix(), mode="train", augmentation=args.apply_augmentation)
75 | nyuv2_val_set = NYUv2(root=path.as_posix(), mode="val")
76 | nyuv2_test_set = NYUv2(root=path.as_posix(), mode="test")
77 |
78 | train_loader = DataLoader(dataset=nyuv2_train_set, batch_size=bs, shuffle=True)
79 | val_loader = DataLoader(dataset=nyuv2_val_set, batch_size=bs, shuffle=False)
80 | test_loader = DataLoader(dataset=nyuv2_test_set, batch_size=bs, shuffle=False)
81 |
82 | # loss_weight method
83 | loss_weight_methods_parameters = extract_weight_method_parameters_from_args(args)
84 | loss_weight_method = LossWeightMethods(
85 | args.loss_method, n_tasks=3, device=device, **loss_weight_methods_parameters[args.loss_method]
86 | )
87 |
88 | # gradient_weight method
89 | gradient_weight_methods_parameters = extract_weight_method_parameters_from_args(args)
90 | gradient_weight_method = GradientWeightMethods(
91 | args.gradient_method, n_tasks=3, device=device, **gradient_weight_methods_parameters[args.gradient_method]
92 | )
93 |
94 | # optimizer
95 | optimizer = torch.optim.Adam(
96 | [
97 | dict(params=model.parameters(), lr=lr),
98 | dict(params=loss_weight_method.parameters(), lr=args.method_params_lr),
99 | dict(params=gradient_weight_method.parameters(), lr=args.method_params_lr),
100 | ],
101 | )
102 |
103 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
104 |
105 | epochs = args.n_epochs
106 | epoch_iter = trange(epochs)
107 | train_batch = len(train_loader)
108 | val_batch = len(val_loader)
109 | test_batch = len(test_loader)
110 |     avg_cost = np.zeros([epochs, 24], dtype=np.float32)  # 12 train metrics + 12 val metrics per epoch
111 |
112 | # best model to test
113 | best_epoch = None
114 | best_eval = 0
115 |
116 | # print result head
117 | print(
118 | f"LOSS FORMAT: SEMANTIC_LOSS MEAN_IOU PIX_ACC | DEPTH_LOSS ABS_ERR REL_ERR "
119 | f"| NORMAL_LOSS MEAN MED <11.25 <22.5 <30 | ∆m"
120 | )
121 |
122 | train_time_sum = 0.0
123 |
124 | # train batch for IGBv2
125 | if args.loss_method == 'igbv2':
126 | loss_weight_method.method.train_batch = train_batch
127 |
128 | for epoch in epoch_iter:
129 | cost = np.zeros(24, dtype=np.float32)
130 | conf_mat = ConfMatrix(model.segnet.class_nb)
131 | avg_loss_weights = torch.zeros(3).to(device)
132 |
133 | start_train_time = time.time()
134 |
135 | # reward scale for IGBv2
136 | if args.loss_method == 'igbv2':
137 | loss_weight_method.method.reward_scale = lr / optimizer.param_groups[0]['lr']
138 |
139 | for j, batch in enumerate(train_loader):
140 | model.train()
141 | optimizer.zero_grad()
142 |
143 | train_data, train_label, train_depth, train_normal = batch
144 | train_data, train_label = train_data.to(device), train_label.long().to(device)
145 | train_depth, train_normal = train_depth.to(device), train_normal.to(device)
146 |
147 | train_pred, features = model(train_data, return_representation=True)
148 |
149 | losses = torch.stack((calc_loss(train_pred[0], train_label, "semantic"),
150 | calc_loss(train_pred[1], train_depth, "depth"),
151 | calc_loss(train_pred[2], train_normal, "normal")))
152 |
153 | weighted_losses, loss_weights = loss_weight_method.get_weighted_losses(
154 | losses=losses,
155 | shared_parameters=list(model.shared_parameters()),
156 | task_specific_parameters=list(model.task_specific_parameters()),
157 | last_shared_parameters=list(model.last_shared_parameters()),
158 | representation=features,
159 | )
160 | avg_loss_weights += loss_weights['weights'] / train_batch
161 |
162 | loss, gradient_weights = gradient_weight_method.backward(
163 | losses=weighted_losses,
164 | shared_parameters=list(model.shared_parameters()),
165 | task_specific_parameters=list(model.task_specific_parameters()),
166 | last_shared_parameters=list(model.last_shared_parameters()),
167 | representation=features,
168 | )
169 |
170 | optimizer.step()
171 |
172 | # accumulate label prediction for every pixel in training images
173 | conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten())
174 |
175 | cost[0] = losses[0].item()
176 | cost[3] = losses[1].item()
177 | cost[4], cost[5] = depth_error(train_pred[1], train_depth)
178 | cost[6] = losses[2].item()
179 | cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(train_pred[2], train_normal)
180 | avg_cost[epoch, :12] += cost[:12] / train_batch
181 |
182 | epoch_iter.set_description(
183 | f"[{epoch} {j + 1}/{train_batch}] losses: {losses[0].item():.3f} "
184 | f"{losses[1].item():.3f} {losses[2].item():.3f} "
185 | f"weights: {loss_weights['weights'][0].item():.3f} "
186 | f"{loss_weights['weights'][1].item():.3f} {loss_weights['weights'][2].item():.3f}"
187 | )
188 |
189 | # scheduler
190 | scheduler.step()
191 | # compute mIoU and acc
192 | avg_cost[epoch, 1:3] = conf_mat.get_metrics()
193 |
194 | # base_losses for IGBv1 and IGBv2
195 | if 'igb' in args.loss_method and epoch == args.base_epoch:
196 | base_losses = torch.Tensor(avg_cost[epoch, [0, 3, 6]]).to(device)
197 | loss_weight_method.method.base_losses = base_losses
198 |
199 | end_train_time = time.time()
200 | train_time_sum += end_train_time - start_train_time
201 |
202 | # todo: move evaluate to function?
203 | # evaluating test data
204 | model.eval()
205 | conf_mat = ConfMatrix(model.segnet.class_nb)
206 | with torch.no_grad(): # operations inside don't track history
207 | for j, batch in enumerate(val_loader):
208 | val_data, val_label, val_depth, val_normal = batch
209 | val_data, val_label = val_data.to(device), val_label.long().to(device)
210 | val_depth, val_normal = val_depth.to(device), val_normal.to(device)
211 |
212 | val_pred = model(val_data)
213 | val_loss = torch.stack(
214 | (
215 | calc_loss(val_pred[0], val_label, "semantic"),
216 | calc_loss(val_pred[1], val_depth, "depth"),
217 | calc_loss(val_pred[2], val_normal, "normal"),
218 | )
219 | )
220 |
221 | conf_mat.update(val_pred[0].argmax(1).flatten(), val_label.flatten())
222 |
223 | cost[12] = val_loss[0].item()
224 | cost[15] = val_loss[1].item()
225 | cost[16], cost[17] = depth_error(val_pred[1], val_depth)
226 | cost[18] = val_loss[2].item()
227 | cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(val_pred[2], val_normal)
228 | avg_cost[epoch, 12:] += cost[12:] / val_batch
229 |
230 | # compute mIoU and acc
231 | avg_cost[epoch, 13:15] = conf_mat.get_metrics()
232 |
233 | # Val Delta_m
234 | val_delta_m = delta_fn(
235 | avg_cost[epoch, [13, 14, 16, 17, 19, 20, 21, 22, 23]]
236 | )
237 |
238 | if args.loss_method != "stl":
239 | eval_value = val_delta_m
240 | else:
241 | eval_value = stl_eval_mean(avg_cost[epoch, [13, 14, 16, 17, 19, 20, 21, 22, 23]], args.main_task)
242 |
243 | results = f"Epoch: {epoch:04d}\n" \
244 | f"AVERAGE LOSS WEIGHTS: " \
245 | f"{avg_loss_weights[0]:.4f} {avg_loss_weights[1]:.4f} {avg_loss_weights[2]:.4f}\n" \
246 | f"TRAIN: " \
247 | f"{avg_cost[epoch, 0]:.4f} {avg_cost[epoch, 1]:.4f} {avg_cost[epoch, 2]:.4f} | " \
248 | f"{avg_cost[epoch, 3]:.4f} {avg_cost[epoch, 4]:.4f} {avg_cost[epoch, 5]:.4f} | " \
249 | f"{avg_cost[epoch, 6]:.4f} {avg_cost[epoch, 7]:.2f} {avg_cost[epoch, 8]:.2f} " \
250 | f"{avg_cost[epoch, 9]:.4f} {avg_cost[epoch, 10]:.4f} {avg_cost[epoch, 11]:.4f}\n" \
251 | f"VAL: " \
252 | f"{avg_cost[epoch, 12]:.4f} {avg_cost[epoch, 13]:.4f} {avg_cost[epoch, 14]:.4f} | " \
253 | f"{avg_cost[epoch, 15]:.4f} {avg_cost[epoch, 16]:.4f} {avg_cost[epoch, 17]:.4f} | " \
254 | f"{avg_cost[epoch, 18]:.4f} {avg_cost[epoch, 19]:.2f} {avg_cost[epoch, 20]:.2f} " \
255 | f"{avg_cost[epoch, 21]:.4f} {avg_cost[epoch, 22]:.4f} {avg_cost[epoch, 23]:.4f} | " \
256 | f"{val_delta_m:.3f}\n"
257 |
258 | if best_epoch is None or eval_value < best_eval:
259 | best_epoch = epoch
260 | best_eval = eval_value
261 |
262 | # test
263 | test_cost = np.zeros(12, dtype=np.float32)
264 | test_avg_cost = np.zeros(12, dtype=np.float32)
265 | conf_mat = ConfMatrix(model.segnet.class_nb)
266 | with torch.no_grad():
267 | for j, batch in enumerate(test_loader):
268 | test_data, test_label, test_depth, test_normal = batch
269 | test_data, test_label = test_data.to(device), test_label.long().to(device)
270 | test_depth, test_normal = test_depth.to(device), test_normal.to(device)
271 |
272 | test_pred = model(test_data)
273 | test_loss = torch.stack(
274 | (
275 | calc_loss(test_pred[0], test_label, "semantic"),
276 | calc_loss(test_pred[1], test_depth, "depth"),
277 | calc_loss(test_pred[2], test_normal, "normal"),
278 | )
279 | )
280 |
281 | conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten())
282 |
283 | test_cost[0] = test_loss[0].item()
284 | test_cost[3] = test_loss[1].item()
285 | test_cost[4], test_cost[5] = depth_error(test_pred[1], test_depth)
286 | test_cost[6] = test_loss[2].item()
287 | test_cost[7], test_cost[8], test_cost[9], test_cost[10], test_cost[11] = normal_error(
288 | test_pred[2], test_normal
289 | )
290 | test_avg_cost += test_cost / test_batch
291 |
292 | # compute mIoU and acc
293 | test_avg_cost[1:3] = conf_mat.get_metrics()
294 |
295 | # Test Delta_m
296 | test_delta_m = delta_fn(
297 | test_avg_cost[[1, 2, 4, 5, 7, 8, 9, 10, 11]]
298 | )
299 | test_result = f"TEST: {test_avg_cost[0]:.4f} {test_avg_cost[1]:.4f} {test_avg_cost[2]:.4f} | " \
300 | f"{test_avg_cost[3]:.4f} {test_avg_cost[4]:.4f} {test_avg_cost[5]:.4f} | " \
301 | f"{test_avg_cost[6]:.4f} {test_avg_cost[7]:.2f} {test_avg_cost[8]:.2f} " \
302 | f"{test_avg_cost[9]:.4f} {test_avg_cost[10]:.4f} {test_avg_cost[11]:.4f} | " \
303 | f"{test_delta_m:.3f}\n"
304 | results += test_result
305 | # print test result
306 | print(test_result, end='')
307 | with open(log_file, mode="a") as log_f:
308 | log_f.write(results)
309 |
310 | train_time_log = f"Training time: {int(train_time_sum)}s\n"
311 | print(train_time_log, end='')
312 | with open(log_file, mode="a") as log_f:
313 | log_f.write(train_time_log)
314 |
315 |
316 | if __name__ == "__main__":
317 | parser = ArgumentParser("NYUv2", parents=[common_parser])
318 | parser.set_defaults(
319 | data_path="./dataset",
320 | lr=1e-4,
321 | n_epochs=500,
322 | batch_size=2,
323 | )
324 | parser.add_argument(
325 | "--model",
326 | type=str,
327 | default="segnet",
328 | choices=["segnet", "mtan"],
329 | help="model type",
330 | )
331 | parser.add_argument(
332 | "--apply-augmentation",
333 | type=str2bool,
334 | default=True,
335 | help="data augmentations"
336 | )
337 | args = parser.parse_args()
338 |
339 | # set seed
340 | set_seed(args.seed)
341 |
342 | device = get_device(gpus=args.gpu)
343 | main(path=args.data_path, lr=args.lr, bs=args.batch_size, device=device)
344 |
--------------------------------------------------------------------------------
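A small sanity-check sketch for calc_loss above, not part of the repository: the shapes follow the prediction heads in experiments/nyuv2/models.py (13-class log-softmax segmentation, 1-channel depth, unit-norm 3-channel normals); it assumes the repository root is on PYTHONPATH and uses random tensors instead of NYUv2 data.

import torch
import torch.nn.functional as F
from experiments.nyuv2.trainer import calc_loss

B, C, H, W = 2, 13, 8, 8
seg_pred = F.log_softmax(torch.randn(B, C, H, W), dim=1)  # log-probabilities, as produced by the model
seg_label = torch.randint(0, C, (B, H, W))                # class indices; -1 would mark ignored pixels

depth_pred = torch.randn(B, 1, H, W)
depth_gt = torch.rand(B, 1, H, W)                         # all-zero pixels would be masked out

normal_pred = F.normalize(torch.randn(B, 3, H, W), dim=1) # unit-norm normals, matching the model output
normal_gt = F.normalize(torch.randn(B, 3, H, W), dim=1)

print(calc_loss(seg_pred, seg_label, "semantic"))  # pixel-wise NLL
print(calc_loss(depth_pred, depth_gt, "depth"))    # masked L1
print(calc_loss(normal_pred, normal_gt, "normal")) # 1 - mean dot product over valid pixels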
/methods/loss_weight_methods.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple, Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from methods.weight_method import WeightMethod, LinearScalarization
8 | from methods.SAC_Agent import SAC_Agent, RandomBuffer
9 |
10 |
11 | class ScaleInvariantLinearScalarization(WeightMethod):
12 | """Scale-invariant loss balancing paradigm"""
13 |
14 | def __init__(
15 | self,
16 | n_tasks: int,
17 | device: torch.device,
18 | task_weights: Union[List[float], torch.Tensor] = None,
19 | ):
20 | super().__init__(n_tasks, device=device)
21 | if task_weights is None:
22 | task_weights = torch.ones((n_tasks,))
23 | if not isinstance(task_weights, torch.Tensor):
24 | task_weights = torch.tensor(task_weights)
25 | assert len(task_weights) == n_tasks
26 | self.task_weights = task_weights.to(device)
27 |
28 | def get_weighted_loss(self, losses, **kwargs):
29 | loss = torch.sum(torch.log(losses) * self.task_weights)
30 | return loss, dict(weights=self.task_weights)
31 |
32 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
33 | losses = torch.log(losses) * self.task_weights
34 | return losses, dict(weights=self.task_weights)
35 |
36 |
37 | class STL(WeightMethod):
38 | """Single task learning"""
39 |
40 | def __init__(self, n_tasks, device: torch.device, main_task):
41 | super().__init__(n_tasks, device=device)
42 | self.main_task = main_task
43 | self.weights = torch.zeros(n_tasks, device=device)
44 | self.weights[main_task] = 1.0
45 |
46 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
47 | assert len(losses) == self.n_tasks
48 | loss = losses[self.main_task]
49 |
50 | return loss, dict(weights=self.weights)
51 |
52 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
53 | losses = losses * self.weights
54 | return losses, dict(weights=self.weights)
55 |
56 |
57 | class Uncertainty(WeightMethod):
58 | """Implementation of `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics`
59 | Source: https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynb
60 | """
61 |
62 | def __init__(self, n_tasks, device: torch.device):
63 | super().__init__(n_tasks, device=device)
64 | self.logsigma = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True)
65 |
66 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
67 | loss = sum(losses / (2 * self.logsigma.exp()) + self.logsigma / 2)
68 | return loss, dict(weights=torch.exp(-self.logsigma)) # NOTE: not exactly task weights
69 |
70 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
71 | losses = losses / (2 * self.logsigma.exp()) + self.logsigma / 2
72 | return losses, dict(weights=torch.exp(-self.logsigma))
73 |
74 | def parameters(self) -> List[torch.Tensor]:
75 | return [self.logsigma]
76 |
77 |
78 | class UncertaintyLog(WeightMethod):
79 | """UW + SI"""
80 |
81 | def __init__(self, n_tasks, device: torch.device):
82 | super().__init__(n_tasks, device=device)
83 | self.logsigma = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True)
84 |
85 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
86 | loss = sum(torch.log(losses) / (2 * self.logsigma.exp()) + self.logsigma / 2)
87 | return loss, dict(weights=torch.exp(-self.logsigma)) # NOTE: not exactly task weights
88 |
89 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
90 | losses = torch.log(losses) / (2 * self.logsigma.exp()) + self.logsigma / 2
91 | return losses, dict(weights=torch.exp(-self.logsigma)) # NOTE: not exactly task weights
92 |
93 | def parameters(self) -> List[torch.Tensor]:
94 | return [self.logsigma]
95 |
96 |
97 | class RLW(WeightMethod):
98 | """Random loss weighting: https://arxiv.org/pdf/2111.10603.pdf"""
99 |
100 | def __init__(self, n_tasks, device: torch.device):
101 | super().__init__(n_tasks, device=device)
102 |
103 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
104 | assert len(losses) == self.n_tasks
105 | weight = (self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device)
106 | loss = torch.sum(losses * weight)
107 |
108 | return loss, dict(weights=weight)
109 |
110 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
111 | weight = (self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device)
112 | losses = losses * weight
113 | return losses, dict(weights=weight)
114 |
115 |
116 | class RLWLog(WeightMethod):
117 | """RLW + SI"""
118 |
119 | def __init__(self, n_tasks, device: torch.device):
120 | super().__init__(n_tasks, device=device)
121 |
122 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
123 | assert len(losses) == self.n_tasks
124 | weight = (self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device)
125 | loss = torch.sum(torch.log(losses) * weight)
126 |
127 | return loss, dict(weights=weight)
128 |
129 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
130 | weight = (self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device)
131 | losses = torch.log(losses) * weight
132 | return losses, dict(weights=weight)
133 |
134 |
135 | class DynamicWeightAverage(WeightMethod):
136 | """Dynamic Weight Average from `End-to-End Multi-Task Learning with Attention`.
137 | Modification of: https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_segnet_split.py#L242
138 | """
139 |
140 | def __init__(
141 | self, n_tasks, device: torch.device, iteration_window: int = 25, temp=2.0
142 | ):
143 | """
144 |
145 | Parameters
146 | ----------
147 |         n_tasks : number of tasks
148 |         iteration_window : task losses are averaged over the last `iteration_window` iterations
149 |         temp : softmax temperature
150 | """
151 | super().__init__(n_tasks, device=device)
152 | self.iteration_window = iteration_window
153 | self.temp = temp
154 | self.running_iterations = 0
155 | self.costs = np.ones((iteration_window * 2, n_tasks), dtype=np.float32)
156 | self.weights = np.ones(n_tasks, dtype=np.float32)
157 |
158 | def get_weighted_loss(self, losses, **kwargs):
159 |
160 | cost = losses.detach().cpu().numpy()
161 |
162 | # update costs - fifo
163 | self.costs[:-1, :] = self.costs[1:, :]
164 | self.costs[-1, :] = cost
165 |
166 | if self.running_iterations > self.iteration_window:
167 | ws = self.costs[self.iteration_window:, :].mean(0) / self.costs[: self.iteration_window, :].mean(0)
168 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (np.exp(ws / self.temp)).sum()
169 |
170 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(losses.device)
171 | loss = sum(task_weights * losses)
172 |
173 | self.running_iterations += 1
174 |
175 | return loss, dict(weights=task_weights)
176 |
177 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
178 | cost = losses.detach().cpu().numpy()
179 |
180 | # update costs - fifo
181 | self.costs[:-1, :] = self.costs[1:, :]
182 | self.costs[-1, :] = cost
183 |
184 | if self.running_iterations > self.iteration_window:
185 | ws = self.costs[self.iteration_window:, :].mean(0) / self.costs[: self.iteration_window, :].mean(0)
186 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (np.exp(ws / self.temp)).sum()
187 |
188 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(losses.device)
189 | losses = task_weights * losses
190 |
191 | self.running_iterations += 1
192 |
193 | return losses, dict(weights=task_weights)
194 |
195 |
196 | class DynamicWeightAverageLog(WeightMethod):
197 | """DWA + SI"""
198 | def __init__(
199 | self, n_tasks, device: torch.device, iteration_window: int = 25, temp=2.0
200 | ):
201 | """
202 |
203 | Parameters
204 | ----------
205 |         n_tasks : number of tasks
206 |         iteration_window : task losses are averaged over the last `iteration_window` iterations
207 |         temp : softmax temperature
208 | """
209 | super().__init__(n_tasks, device=device)
210 | self.iteration_window = iteration_window
211 | self.temp = temp
212 | self.running_iterations = 0
213 | self.costs = np.ones((iteration_window * 2, n_tasks), dtype=np.float32)
214 | self.weights = np.ones(n_tasks, dtype=np.float32)
215 |
216 | def get_weighted_loss(self, losses, **kwargs):
217 |
218 | cost = losses.detach().cpu().numpy()
219 |
220 | # update costs - fifo
221 | self.costs[:-1, :] = self.costs[1:, :]
222 | self.costs[-1, :] = cost
223 |
224 | if self.running_iterations > self.iteration_window:
225 | ws = self.costs[self.iteration_window:, :].mean(0) / self.costs[: self.iteration_window, :].mean(0)
226 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (np.exp(ws / self.temp)).sum()
227 |
228 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(losses.device)
229 | loss = sum(task_weights * torch.log(losses))
230 |
231 | self.running_iterations += 1
232 |
233 | return loss, dict(weights=task_weights)
234 |
235 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
236 | cost = losses.detach().cpu().numpy()
237 |
238 | # update costs - fifo
239 | self.costs[:-1, :] = self.costs[1:, :]
240 | self.costs[-1, :] = cost
241 |
242 | if self.running_iterations > self.iteration_window:
243 | ws = self.costs[self.iteration_window:, :].mean(0) / self.costs[: self.iteration_window, :].mean(0)
244 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (np.exp(ws / self.temp)).sum()
245 |
246 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(losses.device)
247 | losses = task_weights * torch.log(losses)
248 |
249 | self.running_iterations += 1
250 |
251 | return losses, dict(weights=task_weights)
252 |
253 |
254 | class ImprovableGapBalancing_v1(WeightMethod):
255 | def __init__(self, n_tasks, device: torch.device):
256 | super().__init__(n_tasks, device=device)
257 | self.base_losses = None
258 | self.weights = torch.ones(n_tasks).to(device)
259 |
260 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
261 | if self.base_losses is not None:
262 | self.weights = self.n_tasks * F.softmax(losses.detach() / self.base_losses, dim=-1).to(losses.device)
263 | loss = sum(self.weights * torch.log(losses))
264 | return loss, dict(weights=self.weights) # NOTE: not exactly task weights
265 |
266 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
267 | if self.base_losses is not None:
268 | self.weights = self.n_tasks * F.softmax(losses.detach() / self.base_losses, dim=-1).to(losses.device)
269 | losses = self.weights * torch.log(losses)
270 | return losses, dict(weights=self.weights) # NOTE: not exactly task weights
271 |
272 |
273 | class ImprovableGapBalancing_v2(WeightMethod):
274 | def __init__(self, n_tasks, device: torch.device, sac_lr=3e-4, buffer_size=1e4):
275 | super().__init__(n_tasks, device=device)
276 | self.base_losses = None
277 | self.weights = torch.ones(n_tasks).to(device)
278 |
279 | self.sac_model = SAC_Agent(state_dim=n_tasks, action_dim=n_tasks, a_lr=sac_lr, c_lr=sac_lr, batch_size=256, device=device)
280 | self.replay_buffer = RandomBuffer(state_dim=n_tasks, action_dim=n_tasks, max_size=buffer_size, device=device)
281 | self.custom_step = 0
282 | self.bool_custom_step = 0
283 | self.batch_loss = torch.zeros([2, n_tasks]).to(device)
284 | self.batch_rl_weight = torch.zeros([2, n_tasks]).to(device)
285 | self.train_batch = None
286 | self.start_epoch = 5
287 | self.update_after = 3
288 | self.update_every = 50
289 | self.reward_scale = 1.0
290 |
291 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs):
292 | self.batch_loss[self.bool_custom_step] = losses.detach()
293 |
294 | # write random buffer
295 | if self.base_losses is not None:
296 | loss_de = (self.batch_loss[(self.bool_custom_step - 1) % 2] - self.batch_loss[self.bool_custom_step])
297 | loss_de = loss_de / self.base_losses
298 | reward = min(loss_de)
299 | reward *= self.reward_scale
300 | self.replay_buffer.add(self.batch_loss[(self.bool_custom_step - 1) % 2],
301 | self.batch_rl_weight[(self.bool_custom_step - 1) % 2],
302 | reward,
303 | self.batch_loss[self.bool_custom_step])
304 | # train sac_model
305 | if self.custom_step >= self.update_after * self.train_batch and self.custom_step % self.update_every == 0:
306 | k = 1 + self.replay_buffer.size / self.replay_buffer.max_size
307 | for i in range(int(self.update_every * (k / 2))):
308 | self.sac_model.train(self.replay_buffer, k)
309 | # change weights
310 | if self.custom_step < self.start_epoch * self.train_batch:
311 | self.weights = self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1).to(self.device)
312 | else:
313 | self.weights = self.sac_model.select_action(self.batch_loss[self.bool_custom_step],
314 | deterministic=False,
315 | with_logprob=False)
316 | self.batch_rl_weight[self.bool_custom_step] = self.weights.detach()
317 |
318 | loss = sum(self.weights * torch.log(losses))
319 |
320 | self.custom_step += 1
321 | self.bool_custom_step = (self.bool_custom_step + 1) % 2
322 |
323 | return loss, dict(weights=self.weights) # NOTE: not exactly task weights
324 |
325 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
326 | self.batch_loss[self.bool_custom_step] = losses.detach()
327 |
328 | # write random buffer
329 | if self.base_losses is not None:
330 | loss_de = (self.batch_loss[(self.bool_custom_step - 1) % 2] - self.batch_loss[self.bool_custom_step])
331 | loss_de = loss_de / self.base_losses
332 | reward = min(loss_de)
333 | # reward = sum(loss_de) / self.n_tasks
334 | reward *= self.reward_scale
335 | self.replay_buffer.add(self.batch_loss[(self.bool_custom_step - 1) % 2],
336 | self.batch_rl_weight[(self.bool_custom_step - 1) % 2],
337 | reward,
338 | self.batch_loss[self.bool_custom_step])
339 | # train sac_model
340 | if self.custom_step >= self.update_after * self.train_batch and self.custom_step % self.update_every == 0:
341 | k = 1 + self.replay_buffer.size / self.replay_buffer.max_size
342 | for i in range(int(self.update_every * (k / 2))):
343 | self.sac_model.train(self.replay_buffer, k)
344 | # give weights
345 | if self.custom_step < self.start_epoch * self.train_batch:
346 | self.weights = self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1).to(self.device)
347 | else:
348 | self.weights = self.sac_model.select_action(self.batch_loss[self.bool_custom_step],
349 | deterministic=False,
350 | with_logprob=False)
351 | self.batch_rl_weight[self.bool_custom_step] = self.weights.detach()
352 |
353 | losses = self.weights * torch.log(losses)
354 |
355 | self.custom_step += 1
356 | self.bool_custom_step = (self.bool_custom_step + 1) % 2
357 |
358 | return losses, dict(weights=self.weights) # NOTE: not exactly task weights
359 |
360 |
361 | class LossWeightMethods:
362 | def __init__(self, method: str, n_tasks: int, device: torch.device, **kwargs):
363 | """
364 | :param method:
365 | """
366 | assert method in list(LOSS_METHODS.keys()), f"unknown method {method}."
367 |
368 | self.method = LOSS_METHODS[method](n_tasks=n_tasks, device=device, **kwargs)
369 |
370 | def get_weighted_loss(self, losses, **kwargs):
371 | return self.method.get_weighted_loss(losses, **kwargs)
372 |
373 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs):
374 | return self.method.get_weighted_losses(losses, **kwargs)
375 |
376 | def backward(
377 | self, losses, **kwargs
378 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
379 | return self.method.backward(losses, **kwargs)
380 |
381 |     def __call__(self, losses, **kwargs):
382 | return self.backward(losses, **kwargs)
383 |
384 | def parameters(self):
385 | return self.method.parameters()
386 |
387 |
388 | LOSS_METHODS = dict(
389 | ls=LinearScalarization,
390 | stl=STL,
391 | si=ScaleInvariantLinearScalarization,
392 | uw=Uncertainty,
393 | uwlog=UncertaintyLog,
394 | rlw=RLW,
395 | rlwlog=RLWLog,
396 | dwa=DynamicWeightAverage,
397 | dwalog=DynamicWeightAverageLog,
398 | igbv1=ImprovableGapBalancing_v1,
399 | igbv2=ImprovableGapBalancing_v2,
400 | )
401 |
--------------------------------------------------------------------------------
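A brief usage sketch of the wrapper above, not part of the repository: it instantiates the uncertainty-weighting method ('uw') on a toy loss vector and shows that its learnable parameters can be handed to the main optimizer, as the NYUv2 trainer does; the loss values are placeholders.

import torch
from methods.loss_weight_methods import LossWeightMethods

device = torch.device("cpu")
method = LossWeightMethods("uw", n_tasks=3, device=device)

losses = torch.tensor([0.9, 0.4, 1.3], requires_grad=True)  # stand-in for per-task training losses
weighted_losses, extra = method.get_weighted_losses(losses)
print(weighted_losses, extra["weights"])

# the log-sigma parameters are learned jointly with the network, via the same optimizer
optimizer = torch.optim.Adam(method.parameters(), lr=1e-3)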
/methods/gradient_weight_methods.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | from typing import Dict, List, Tuple, Union
4 |
5 | import cvxpy as cp
6 | import numpy as np
7 | import torch
8 | from scipy.optimize import minimize
9 |
10 | from methods.min_norm_solvers import MinNormSolver, gradient_normalizers
11 | from methods.weight_method import WeightMethod, LinearScalarization
12 |
13 |
14 | class NashMTL(WeightMethod):
15 | def __init__(
16 | self,
17 | n_tasks: int,
18 | device: torch.device,
19 | max_norm: float = 1.0,
20 | update_weights_every: int = 1,
21 | optim_niter=20,
22 | ):
23 | super(NashMTL, self).__init__(n_tasks=n_tasks, device=device)
24 |
25 | self.optim_niter = optim_niter
26 | self.update_weights_every = update_weights_every
27 | self.max_norm = max_norm
28 |
29 | self.prvs_alpha_param = None
30 | self.normalization_factor = np.ones((1,))
31 |         self.init_gtg = np.eye(self.n_tasks)
32 | self.step = 0.0
33 | self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)
34 |
35 | def _stop_criteria(self, gtg, alpha_t):
36 | return ((self.alpha_param.value is None)
37 | or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3)
38 | or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6))
39 |
40 | def solve_optimization(self, gtg: np.array):
41 | self.G_param.value = gtg
42 | self.normalization_factor_param.value = self.normalization_factor
43 |
44 | alpha_t = self.prvs_alpha
45 | for _ in range(self.optim_niter):
46 | self.alpha_param.value = alpha_t
47 | self.prvs_alpha_param.value = alpha_t
48 |
49 | try:
50 | self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100)
51 |             except Exception:  # fall back to the previous solution if the solver fails
52 | self.alpha_param.value = self.prvs_alpha_param.value
53 |
54 | if self._stop_criteria(gtg, alpha_t):
55 | break
56 |
57 | alpha_t = self.alpha_param.value
58 |
59 | if alpha_t is not None:
60 | self.prvs_alpha = alpha_t
61 |
62 | return self.prvs_alpha
63 |
64 | def _calc_phi_alpha_linearization(self):
65 | G_prvs_alpha = self.G_param @ self.prvs_alpha_param
66 | prvs_phi_tag = 1 / self.prvs_alpha_param + (1 / G_prvs_alpha) @ self.G_param
67 | phi_alpha = prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param)
68 | return phi_alpha
69 |
70 | def _init_optim_problem(self):
71 | self.alpha_param = cp.Variable(shape=(self.n_tasks,), nonneg=True)
72 | self.prvs_alpha_param = cp.Parameter(shape=(self.n_tasks,), value=self.prvs_alpha)
73 | self.G_param = cp.Parameter(shape=(self.n_tasks, self.n_tasks), value=self.init_gtg)
74 | self.normalization_factor_param = cp.Parameter(shape=(1,), value=np.array([1.0]))
75 |
76 | self.phi_alpha = self._calc_phi_alpha_linearization()
77 |
78 | G_alpha = self.G_param @ self.alpha_param
79 | constraint = []
80 | for i in range(self.n_tasks):
81 | constraint.append(-cp.log(self.alpha_param[i] * self.normalization_factor_param) - cp.log(G_alpha[i]) <= 0)
82 | obj = cp.Minimize(cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param)
83 | self.prob = cp.Problem(obj, constraint)
84 |
85 | def get_weighted_loss(
86 | self,
87 | losses,
88 | shared_parameters,
89 | **kwargs,
90 | ):
91 | """
92 |
93 | Parameters
94 | ----------
95 | losses :
96 | shared_parameters : shared parameters
97 | kwargs :
98 |
99 | Returns
100 | -------
101 |
102 | """
103 |
104 | extra_outputs = dict()
105 | if self.step == 0:
106 | self._init_optim_problem()
107 |
108 | if self.step % self.update_weights_every == 0:
109 | self.step += 1
110 |
111 | grads = {}
112 | for i, loss in enumerate(losses):
113 | g = list(torch.autograd.grad(loss, shared_parameters, retain_graph=True))
114 | grad = torch.cat([torch.flatten(grad) for grad in g])
115 | grads[i] = grad
116 |
117 | G = torch.stack(tuple(v for v in grads.values()))
118 | GTG = torch.mm(G, G.t())
119 |
120 | self.normalization_factor = (torch.norm(GTG).detach().cpu().numpy().reshape((1,)))
121 | GTG = GTG / self.normalization_factor.item()
122 | alpha = self.solve_optimization(GTG.cpu().detach().numpy())
123 | alpha = torch.from_numpy(alpha)
124 |
125 | else:
126 | self.step += 1
127 | alpha = self.prvs_alpha
128 |
129 | weighted_loss = sum([losses[i] * alpha[i] for i in range(len(alpha))])
130 | extra_outputs["weights"] = alpha
131 | return weighted_loss, extra_outputs
132 |
133 | def backward(
134 | self,
135 | losses: torch.Tensor,
136 | shared_parameters: Union[
137 | List[torch.nn.parameter.Parameter], torch.Tensor
138 | ] = None,
139 | task_specific_parameters: Union[
140 | List[torch.nn.parameter.Parameter], torch.Tensor
141 | ] = None,
142 | last_shared_parameters: Union[
143 | List[torch.nn.parameter.Parameter], torch.Tensor
144 | ] = None,
145 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
146 | **kwargs,
147 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
148 | loss, extra_outputs = self.get_weighted_loss(
149 | losses=losses,
150 | shared_parameters=shared_parameters,
151 | **kwargs,
152 | )
153 | loss.backward()
154 |
155 |         # make sure the gradient of the shared params has norm <= self.max_norm
156 | if self.max_norm > 0:
157 | torch.nn.utils.clip_grad_norm_(shared_parameters, self.max_norm)
158 |
159 | return loss, extra_outputs
160 |
161 |
162 | class MGDA(WeightMethod):
163 | """Based on the official implementation of: Multi-Task Learning as Multi-Objective Optimization
164 | Ozan Sener, Vladlen Koltun
165 | Neural Information Processing Systems (NeurIPS) 2018
166 | https://github.com/intel-isl/MultiObjectiveOptimization
167 |
168 | """
169 |
170 | def __init__(
171 | self, n_tasks, device: torch.device, params="shared", normalization="none"
172 | ):
173 | super().__init__(n_tasks, device=device)
174 | self.solver = MinNormSolver()
175 | assert params in ["shared", "last", "rep"]
176 | self.params = params
177 | assert normalization in ["norm", "loss", "loss+", "none"]
178 | self.normalization = normalization
179 |
180 | @staticmethod
181 | def _flattening(grad):
182 |         return torch.cat(tuple(g.reshape(-1) for g in grad), dim=0)
183 |
184 | def get_weighted_loss(
185 | self,
186 | losses,
187 | shared_parameters=None,
188 | last_shared_parameters=None,
189 | representation=None,
190 | **kwargs,
191 | ):
192 | """
193 |
194 | Parameters
195 | ----------
196 | losses :
197 | shared_parameters :
198 | last_shared_parameters :
199 | representation :
200 | kwargs :
201 |
202 | Returns
203 | -------
204 |
205 | """
206 | # Our code
207 | grads = {}
208 | params = dict(rep=representation, shared=shared_parameters, last=last_shared_parameters)[self.params]
209 | for i, loss in enumerate(losses):
210 |             g = list(torch.autograd.grad(loss, params, retain_graph=True))
211 | # Normalize all gradients, this is optional and not included in the paper.
212 |
213 | grads[i] = [torch.flatten(grad) for grad in g]
214 |
215 | gn = gradient_normalizers(grads, losses, self.normalization)
216 | for t in range(self.n_tasks):
217 | for gr_i in range(len(grads[t])):
218 | grads[t][gr_i] = grads[t][gr_i] / gn[t]
219 |
220 | sol, min_norm = self.solver.find_min_norm_element([grads[t] for t in range(len(grads))])
221 | sol = sol * self.n_tasks # make sure it sums to self.n_tasks
222 | weighted_loss = sum([losses[i] * sol[i] for i in range(len(sol))])
223 |
224 | return weighted_loss, dict(weights=torch.from_numpy(sol.astype(np.float32)))
225 |
226 |
227 | class PCGrad(WeightMethod):
228 | """Modification of: https://github.com/WeiChengTseng/Pytorch-PCGrad/blob/master/pcgrad.py
229 |
230 | @misc{Pytorch-PCGrad,
231 | author = {Wei-Cheng Tseng},
232 | title = {WeiChengTseng/Pytorch-PCGrad},
233 | url = {https://github.com/WeiChengTseng/Pytorch-PCGrad.git},
234 | year = {2020}
235 | }
236 |
237 | """
238 |
239 | def __init__(self, n_tasks: int, device: torch.device, reduction="sum"):
240 | super().__init__(n_tasks, device=device)
241 | assert reduction in ["mean", "sum"]
242 | self.reduction = reduction
243 |
244 | def get_weighted_loss(
245 | self,
246 | losses: torch.Tensor,
247 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
248 | task_specific_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
249 | **kwargs,
250 | ):
251 | raise NotImplementedError
252 |
253 | def _set_pc_grads(self, losses, shared_parameters, task_specific_parameters=None):
254 | # shared part
255 | shared_grads = []
256 | for l in losses:
257 | shared_grads.append(torch.autograd.grad(l, shared_parameters, retain_graph=True))
258 |
259 | if isinstance(shared_parameters, torch.Tensor):
260 | shared_parameters = [shared_parameters]
261 | non_conflict_shared_grads = self._project_conflicting(shared_grads)
262 | for p, g in zip(shared_parameters, non_conflict_shared_grads):
263 | p.grad = g
264 |
265 | # task specific part
266 | if task_specific_parameters is not None:
267 | task_specific_grads = torch.autograd.grad(losses.sum(), task_specific_parameters)
268 | if isinstance(task_specific_parameters, torch.Tensor):
269 | task_specific_parameters = [task_specific_parameters]
270 | for p, g in zip(task_specific_parameters, task_specific_grads):
271 | p.grad = g
272 |
273 | def _project_conflicting(self, grads: List[Tuple[torch.Tensor]]):
274 | pc_grad = copy.deepcopy(grads)
275 | for g_i in pc_grad:
276 | random.shuffle(grads)
277 | for g_j in grads:
278 | g_i_g_j = sum([torch.dot(torch.flatten(grad_i), torch.flatten(grad_j))
279 | for grad_i, grad_j in zip(g_i, g_j)])
280 | if g_i_g_j < 0:
281 | g_j_norm_square = (torch.norm(torch.cat([torch.flatten(g) for g in g_j])) ** 2)
282 | for grad_i, grad_j in zip(g_i, g_j):
283 | grad_i -= g_i_g_j * grad_j / g_j_norm_square
284 |
285 | merged_grad = [sum(g) for g in zip(*pc_grad)]
286 | if self.reduction == "mean":
287 | merged_grad = [g / self.n_tasks for g in merged_grad]
288 |
289 | return merged_grad
290 |
291 | def backward(
292 | self,
293 | losses: torch.Tensor,
294 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
295 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
296 | task_specific_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
297 | **kwargs,
298 | ):
299 | self._set_pc_grads(losses, shared_parameters, task_specific_parameters)
300 | return None, {} # NOTE: to align with all other weight methods
301 |
302 |
303 | class CAGrad(WeightMethod):
304 | def __init__(self, n_tasks, device: torch.device, c=0.4):
305 | super().__init__(n_tasks, device=device)
306 | self.c = c
307 |
308 | def get_weighted_loss(
309 | self,
310 | losses,
311 | shared_parameters,
312 | **kwargs,
313 | ):
314 | """
315 | Parameters
316 | ----------
317 | losses :
318 | shared_parameters : shared parameters
319 | kwargs :
320 | Returns
321 | -------
322 | """
323 | # NOTE: we allow only shared params for now. Need to see paper for other options.
324 | grad_dims = []
325 | for param in shared_parameters:
326 | grad_dims.append(param.data.numel())
327 | grads = torch.Tensor(sum(grad_dims), self.n_tasks).to(self.device)
328 |
329 | for i in range(self.n_tasks):
330 |             if i < self.n_tasks - 1:  # retain the graph for every backward pass except the last task's
331 | losses[i].backward(retain_graph=True)
332 | else:
333 | losses[i].backward()
334 | self.grad2vec(shared_parameters, grads, grad_dims, i)
335 | # multi_task_model.zero_grad_shared_modules()
336 | for p in shared_parameters:
337 | p.grad = None
338 |
339 | g = self.cagrad(grads, alpha=self.c, rescale=1)
340 | self.overwrite_grad(shared_parameters, g, grad_dims)
341 |
342 | def cagrad(self, grads, alpha=0.5, rescale=1):
343 | GG = grads.t().mm(grads).cpu() # [num_tasks, num_tasks]
344 | g0_norm = (GG.mean() + 1e-8).sqrt() # norm of the average gradient
345 |
346 | x_start = np.ones(self.n_tasks) / self.n_tasks
347 | bnds = tuple((0, 1) for x in x_start)
348 | cons = {"type": "eq", "fun": lambda x: 1 - sum(x)}
349 | A = GG.numpy()
350 | b = x_start.copy()
351 | c = (alpha * g0_norm + 1e-8).item()
352 |
353 | def objfn(x):
354 | return (x.reshape(1, self.n_tasks).dot(A).dot(b.reshape(self.n_tasks, 1))
355 | + c * np.sqrt(x.reshape(1, self.n_tasks).dot(A).dot(x.reshape(self.n_tasks, 1)) + 1e-8)).sum()
356 |
357 | res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
358 | w_cpu = res.x
359 | ww = torch.Tensor(w_cpu).to(grads.device)
360 | gw = (grads * ww.view(1, -1)).sum(1)
361 | gw_norm = gw.norm()
362 | lmbda = c / (gw_norm + 1e-8)
363 | g = grads.mean(1) + lmbda * gw
364 | if rescale == 0:
365 | return g
366 | elif rescale == 1:
367 | return g / (1 + alpha ** 2)
368 | else:
369 | return g / (1 + alpha)
370 |
371 | @staticmethod
372 | def grad2vec(shared_params, grads, grad_dims, task):
373 | # store the gradients
374 | grads[:, task].fill_(0.0)
375 | cnt = 0
376 | # for mm in m.shared_modules():
377 | # for p in mm.parameters():
378 |
379 | for param in shared_params:
380 | grad = param.grad
381 | if grad is not None:
382 | grad_cur = grad.data.detach().clone()
383 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
384 | en = sum(grad_dims[: cnt + 1])
385 | grads[beg:en, task].copy_(grad_cur.data.view(-1))
386 | cnt += 1
387 |
388 | def overwrite_grad(self, shared_parameters, newgrad, grad_dims):
389 | newgrad = newgrad * self.n_tasks # to match the sum loss
390 | cnt = 0
391 |
392 | # for mm in m.shared_modules():
393 | # for param in mm.parameters():
394 | for param in shared_parameters:
395 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
396 | en = sum(grad_dims[: cnt + 1])
397 | this_grad = newgrad[beg:en].contiguous().view(param.data.size())
398 | param.grad = this_grad.data.clone()
399 | cnt += 1
400 |
401 | def backward(
402 | self,
403 | losses: torch.Tensor,
404 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
405 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
406 | task_specific_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
407 | **kwargs,
408 | ):
409 | self.get_weighted_loss(losses, shared_parameters)
410 | return None, {} # NOTE: to align with all other weight methods
411 |
412 |
413 | class IMTLG(WeightMethod):
414 | """TOWARDS IMPARTIAL MULTI-TASK LEARNING: https://openreview.net/pdf?id=IMPnRXEWpvr"""
415 |
416 | def __init__(self, n_tasks, device: torch.device):
417 | super().__init__(n_tasks, device=device)
418 |
419 | def get_weighted_loss(
420 | self,
421 | losses,
422 | shared_parameters,
423 | **kwargs,
424 | ):
425 | grads = {}
426 | norm_grads = {}
427 |
428 | for i, loss in enumerate(losses):
429 | g = list(torch.autograd.grad(loss, shared_parameters, retain_graph=True))
430 | grad = torch.cat([torch.flatten(grad) for grad in g])
431 | norm_term = torch.norm(grad)
432 |
433 | grads[i] = grad
434 | norm_grads[i] = grad / norm_term
435 |
436 | G = torch.stack(tuple(v for v in grads.values()))
437 | D = (
438 | G[
439 | 0,
440 | ]
441 | - G[
442 | 1:,
443 | ]
444 | )
445 |
446 | U = torch.stack(tuple(v for v in norm_grads.values()))
447 | U = (
448 | U[
449 | 0,
450 | ]
451 | - U[
452 | 1:,
453 | ]
454 | )
455 | first_element = torch.matmul(
456 | G[
457 | 0,
458 | ],
459 | U.t(),
460 | )
461 | try:
462 | second_element = torch.inverse(torch.matmul(D, U.t()))
463 |         except Exception:
464 | # workaround for cases where matrix is singular
465 | second_element = torch.inverse(
466 | torch.eye(self.n_tasks - 1, device=self.device) * 1e-8
467 | + torch.matmul(D, U.t())
468 | )
469 |
470 | alpha_ = torch.matmul(first_element, second_element)
471 | alpha = torch.cat(
472 | (torch.tensor(1 - alpha_.sum(), device=self.device).unsqueeze(-1), alpha_)
473 | )
474 |
475 | loss = torch.sum(losses * alpha)
476 |
477 | return loss, dict(weights=alpha)
478 |
479 |
480 | class GradientWeightMethods:
481 | def __init__(self, method: str, n_tasks: int, device: torch.device, **kwargs):
482 | """
483 | :param method:
484 | """
485 | assert method in list(GRADIENT_METHODS.keys()), f"unknown method {method}."
486 |
487 | self.method = GRADIENT_METHODS[method](n_tasks=n_tasks, device=device, **kwargs)
488 |
489 | def get_weighted_loss(self, losses, **kwargs):
490 | return self.method.get_weighted_loss(losses, **kwargs)
491 |
492 | def backward(
493 | self, losses, **kwargs
494 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
495 | return self.method.backward(losses, **kwargs)
496 |
497 |     def __call__(self, losses, **kwargs):
498 | return self.backward(losses, **kwargs)
499 |
500 | def parameters(self):
501 | return self.method.parameters()
502 |
503 |
504 | GRADIENT_METHODS = dict(
505 | ls=LinearScalarization,
506 | pcgrad=PCGrad,
507 | mgda=MGDA,
508 | cagrad=CAGrad,
509 | nashmtl=NashMTL,
510 | imtl=IMTLG,
511 | )
512 |
--------------------------------------------------------------------------------
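A compact, hypothetical end-to-end sketch of the two-stage pipeline used in the trainer, not part of the repository: a loss-weighting method ('si') first rescales the task losses, then a gradient-weighting method ('pcgrad') performs the backward pass over a toy shared-bottom model; the model, data, and hyperparameters here are placeholders.

import torch
from methods.loss_weight_methods import LossWeightMethods
from methods.gradient_weight_methods import GradientWeightMethods

device = torch.device("cpu")
shared = torch.nn.Linear(4, 4)                                           # toy shared encoder
heads = torch.nn.ModuleList([torch.nn.Linear(4, 1) for _ in range(2)])   # toy task heads

loss_method = LossWeightMethods("si", n_tasks=2, device=device)
grad_method = GradientWeightMethods("pcgrad", n_tasks=2, device=device)
optimizer = torch.optim.Adam(list(shared.parameters()) + list(heads.parameters()), lr=1e-3)

x, y = torch.randn(8, 4), torch.randn(8, 1)
feats = shared(x)
losses = torch.stack([torch.nn.functional.mse_loss(h(feats), y) for h in heads])

# stage 1: reweight the losses; stage 2: resolve gradient conflicts and populate .grad
weighted_losses, _ = loss_method.get_weighted_losses(losses)
grad_method.backward(
    losses=weighted_losses,
    shared_parameters=list(shared.parameters()),
    task_specific_parameters=list(heads.parameters()),
)
optimizer.step()
optimizer.zero_grad()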
/experiments/nyuv2/models.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class _SegNet(nn.Module):
9 | """SegNet MTAN"""
10 | def __init__(self):
11 | super(_SegNet, self).__init__()
12 | # initialise network parameters
13 | filter = [64, 128, 256, 512, 512]
14 | self.class_nb = 13
15 |
16 | # define encoder decoder layers
17 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
18 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
19 | for i in range(4):
20 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
21 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))
22 |
23 | # define convolution layer
24 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
25 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
26 | for i in range(4):
27 | if i == 0:
28 | self.conv_block_enc.append(
29 | self.conv_layer([filter[i + 1], filter[i + 1]])
30 | )
31 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
32 | else:
33 | self.conv_block_enc.append(
34 | nn.Sequential(
35 | self.conv_layer([filter[i + 1], filter[i + 1]]),
36 | self.conv_layer([filter[i + 1], filter[i + 1]]),
37 | )
38 | )
39 | self.conv_block_dec.append(
40 | nn.Sequential(
41 | self.conv_layer([filter[i], filter[i]]),
42 | self.conv_layer([filter[i], filter[i]]),
43 | )
44 | )
45 |
46 | # define task attention layers
47 | self.encoder_att = nn.ModuleList(
48 | [nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])]
49 | )
50 | self.decoder_att = nn.ModuleList(
51 | [nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])])]
52 | )
53 | self.encoder_block_att = nn.ModuleList(
54 | [self.conv_layer([filter[0], filter[1]])]
55 | )
56 | self.decoder_block_att = nn.ModuleList(
57 | [self.conv_layer([filter[0], filter[0]])]
58 | )
59 |
60 | for j in range(3):
61 | if j < 2:
62 | self.encoder_att.append(
63 | nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])
64 | )
65 | self.decoder_att.append(
66 | nn.ModuleList(
67 | [self.att_layer([2 * filter[0], filter[0], filter[0]])]
68 | )
69 | )
70 | for i in range(4):
71 | self.encoder_att[j].append(
72 | self.att_layer([2 * filter[i + 1], filter[i + 1], filter[i + 1]])
73 | )
74 | self.decoder_att[j].append(
75 | self.att_layer([filter[i + 1] + filter[i], filter[i], filter[i]])
76 | )
77 |
78 | for i in range(4):
79 | if i < 3:
80 | self.encoder_block_att.append(
81 | self.conv_layer([filter[i + 1], filter[i + 2]])
82 | )
83 | self.decoder_block_att.append(
84 | self.conv_layer([filter[i + 1], filter[i]])
85 | )
86 | else:
87 | self.encoder_block_att.append(
88 | self.conv_layer([filter[i + 1], filter[i + 1]])
89 | )
90 | self.decoder_block_att.append(
91 | self.conv_layer([filter[i + 1], filter[i + 1]])
92 | )
93 |
94 | self.pred_task1 = self.conv_layer([filter[0], self.class_nb], pred=True)
95 | self.pred_task2 = self.conv_layer([filter[0], 1], pred=True)
96 | self.pred_task3 = self.conv_layer([filter[0], 3], pred=True)
97 |
98 | # define pooling and unpooling functions
99 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
100 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)
101 |
102 | for m in self.modules():
103 | if isinstance(m, nn.Conv2d):
104 | nn.init.xavier_normal_(m.weight)
105 | nn.init.constant_(m.bias, 0)
106 | elif isinstance(m, nn.BatchNorm2d):
107 | nn.init.constant_(m.weight, 1)
108 | nn.init.constant_(m.bias, 0)
109 | elif isinstance(m, nn.Linear):
110 | nn.init.xavier_normal_(m.weight)
111 | nn.init.constant_(m.bias, 0)
112 |
113 | def shared_modules(self):
114 | return [
115 | self.encoder_block,
116 | self.decoder_block,
117 | self.conv_block_enc,
118 | self.conv_block_dec,
119 | # self.encoder_att, self.decoder_att,
120 | self.encoder_block_att,
121 | self.decoder_block_att,
122 | self.down_sampling,
123 | self.up_sampling,
124 | ]
125 |
126 | def zero_grad_shared_modules(self):
127 | for mm in self.shared_modules():
128 | mm.zero_grad()
129 |
130 | def conv_layer(self, channel, pred=False):
131 | if not pred:
132 | conv_block = nn.Sequential(
133 | nn.Conv2d(
134 | in_channels=channel[0],
135 | out_channels=channel[1],
136 | kernel_size=3,
137 | padding=1,
138 | ),
139 | nn.BatchNorm2d(num_features=channel[1]),
140 | nn.ReLU(inplace=True),
141 | )
142 | else:
143 | conv_block = nn.Sequential(
144 | nn.Conv2d(
145 | in_channels=channel[0],
146 | out_channels=channel[0],
147 | kernel_size=3,
148 | padding=1,
149 | ),
150 | nn.Conv2d(
151 | in_channels=channel[0],
152 | out_channels=channel[1],
153 | kernel_size=1,
154 | padding=0,
155 | ),
156 | )
157 | return conv_block
158 |
159 | def att_layer(self, channel):
160 | att_block = nn.Sequential(
161 | nn.Conv2d(
162 | in_channels=channel[0],
163 | out_channels=channel[1],
164 | kernel_size=1,
165 | padding=0,
166 | ),
167 | nn.BatchNorm2d(channel[1]),
168 | nn.ReLU(inplace=True),
169 | nn.Conv2d(
170 | in_channels=channel[1],
171 | out_channels=channel[2],
172 | kernel_size=1,
173 | padding=0,
174 | ),
175 | nn.BatchNorm2d(channel[2]),
176 | nn.Sigmoid(),
177 | )
178 | return att_block
179 |
180 | def forward(self, x):
181 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = (
182 | [0] * 5 for _ in range(5)
183 | )
184 | for i in range(5):
185 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))
186 |
187 | # define attention list for tasks
188 | atten_encoder, atten_decoder = ([0] * 3 for _ in range(2))
189 | for i in range(3):
190 | atten_encoder[i], atten_decoder[i] = ([0] * 5 for _ in range(2))
191 | for i in range(3):
192 | for j in range(5):
193 | atten_encoder[i][j], atten_decoder[i][j] = ([0] * 3 for _ in range(2))
194 |
195 | # define global shared network
196 | for i in range(5):
197 | if i == 0:
198 | g_encoder[i][0] = self.encoder_block[i](x)
199 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
200 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
201 | else:
202 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
203 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
204 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
205 |
206 | for i in range(5):
207 | if i == 0:
208 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
209 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
210 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
211 | else:
212 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
213 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
214 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
215 |
216 | # define task dependent attention module
217 | for i in range(3):
218 | for j in range(5):
219 | if j == 0:
220 | atten_encoder[i][j][0] = self.encoder_att[i][j](g_encoder[j][0])
221 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
222 | atten_encoder[i][j][2] = self.encoder_block_att[j](
223 | atten_encoder[i][j][1]
224 | )
225 | atten_encoder[i][j][2] = F.max_pool2d(
226 | atten_encoder[i][j][2], kernel_size=2, stride=2
227 | )
228 | else:
229 | atten_encoder[i][j][0] = self.encoder_att[i][j](
230 | torch.cat((g_encoder[j][0], atten_encoder[i][j - 1][2]), dim=1)
231 | )
232 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
233 | atten_encoder[i][j][2] = self.encoder_block_att[j](
234 | atten_encoder[i][j][1]
235 | )
236 | atten_encoder[i][j][2] = F.max_pool2d(
237 | atten_encoder[i][j][2], kernel_size=2, stride=2
238 | )
239 |
240 | for j in range(5):
241 | if j == 0:
242 | atten_decoder[i][j][0] = F.interpolate(
243 | atten_encoder[i][-1][-1],
244 | scale_factor=2,
245 | mode="bilinear",
246 | align_corners=True,
247 | )
248 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](
249 | atten_decoder[i][j][0]
250 | )
251 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](
252 | torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1)
253 | )
254 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]
255 | else:
256 | atten_decoder[i][j][0] = F.interpolate(
257 | atten_decoder[i][j - 1][2],
258 | scale_factor=2,
259 | mode="bilinear",
260 | align_corners=True,
261 | )
262 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](
263 | atten_decoder[i][j][0]
264 | )
265 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](
266 | torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1)
267 | )
268 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]
269 |
270 | # define task prediction layers
271 | t1_pred = F.log_softmax(self.pred_task1(atten_decoder[0][-1][-1]), dim=1)
272 | t2_pred = self.pred_task2(atten_decoder[1][-1][-1])
273 | t3_pred = self.pred_task3(atten_decoder[2][-1][-1])
274 | t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)
275 |
276 | return (
277 | [t1_pred, t2_pred, t3_pred],
278 | (
279 | atten_decoder[0][-1][-1],
280 | atten_decoder[1][-1][-1],
281 | atten_decoder[2][-1][-1],
282 | ),
283 | )
284 |
285 |
286 | class SegNetMtan(nn.Module):
287 | def __init__(self):
288 | super().__init__()
289 | self.segnet = _SegNet()
290 |
291 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
292 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n)
293 |
294 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]:
295 | return (p for n, p in self.segnet.named_parameters() if "pred" in n)
296 |
297 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
298 | """Parameters of the last shared layer.
299 | Returns
300 | -------
301 | """
302 | return []
303 |
304 | def forward(self, x, return_representation=False):
305 | if return_representation:
306 | return self.segnet(x)
307 | else:
308 | pred, rep = self.segnet(x)
309 | return pred
310 |
311 |
312 | class SegNetSplit(nn.Module):
313 | def __init__(self, model_type="standard"):
314 | super(SegNetSplit, self).__init__()
315 | # initialise network parameters
316 | assert model_type in ["standard", "wide", "deep"]
317 | self.model_type = model_type
318 | if self.model_type == "wide":
319 | filter = [64, 128, 256, 512, 1024]
320 | else:
321 | filter = [64, 128, 256, 512, 512]
322 |
323 | self.class_nb = 13
324 |
325 | # define encoder decoder layers
326 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
327 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
328 | for i in range(4):
329 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
330 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))
331 |
332 | # define convolution layer
333 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
334 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
335 | for i in range(4):
336 | if i == 0:
337 | self.conv_block_enc.append(
338 | self.conv_layer([filter[i + 1], filter[i + 1]])
339 | )
340 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
341 | else:
342 | self.conv_block_enc.append(
343 | nn.Sequential(
344 | self.conv_layer([filter[i + 1], filter[i + 1]]),
345 | self.conv_layer([filter[i + 1], filter[i + 1]]),
346 | )
347 | )
348 | self.conv_block_dec.append(
349 | nn.Sequential(
350 | self.conv_layer([filter[i], filter[i]]),
351 | self.conv_layer([filter[i], filter[i]]),
352 | )
353 | )
354 |
355 | # define task specific layers
356 | self.pred_task1 = nn.Sequential(
357 | nn.Conv2d(
358 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
359 | ),
360 | nn.Conv2d(
361 | in_channels=filter[0],
362 | out_channels=self.class_nb,
363 | kernel_size=1,
364 | padding=0,
365 | ),
366 | )
367 | self.pred_task2 = nn.Sequential(
368 | nn.Conv2d(
369 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
370 | ),
371 | nn.Conv2d(in_channels=filter[0], out_channels=1, kernel_size=1, padding=0),
372 | )
373 | self.pred_task3 = nn.Sequential(
374 | nn.Conv2d(
375 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
376 | ),
377 | nn.Conv2d(in_channels=filter[0], out_channels=3, kernel_size=1, padding=0),
378 | )
379 |
380 |         # max-pooling (returning indices) and the matching unpooling used in the decoder
381 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
382 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)
383 |
384 | for m in self.modules():
385 | if isinstance(m, nn.Conv2d):
386 | nn.init.xavier_normal_(m.weight)
387 | nn.init.constant_(m.bias, 0)
388 | elif isinstance(m, nn.BatchNorm2d):
389 | nn.init.constant_(m.weight, 1)
390 | nn.init.constant_(m.bias, 0)
391 | elif isinstance(m, nn.Linear):
392 | nn.init.xavier_normal_(m.weight)
393 | nn.init.constant_(m.bias, 0)
394 |
395 | # define convolutional block
396 | def conv_layer(self, channel):
397 | if self.model_type == "deep":
398 | conv_block = nn.Sequential(
399 | nn.Conv2d(
400 | in_channels=channel[0],
401 | out_channels=channel[1],
402 | kernel_size=3,
403 | padding=1,
404 | ),
405 | nn.BatchNorm2d(num_features=channel[1]),
406 | nn.ReLU(inplace=True),
407 | nn.Conv2d(
408 | in_channels=channel[1],
409 | out_channels=channel[1],
410 | kernel_size=3,
411 | padding=1,
412 | ),
413 | nn.BatchNorm2d(num_features=channel[1]),
414 | nn.ReLU(inplace=True),
415 | )
416 | else:
417 | conv_block = nn.Sequential(
418 | nn.Conv2d(
419 | in_channels=channel[0],
420 | out_channels=channel[1],
421 | kernel_size=3,
422 | padding=1,
423 | ),
424 | nn.BatchNorm2d(num_features=channel[1]),
425 | nn.ReLU(inplace=True),
426 | )
427 | return conv_block
428 |
429 | def forward(self, x):
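        # Per-stage bookkeeping (added comment): encoder/decoder activations,
        # max-pooled and unpooled feature maps, and the pooling indices reused for unpooling.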
430 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = (
431 | [0] * 5 for _ in range(5)
432 | )
433 | for i in range(5):
434 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))
435 |
436 | # global shared encoder-decoder network
437 | for i in range(5):
438 | if i == 0:
439 | g_encoder[i][0] = self.encoder_block[i](x)
440 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
441 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
442 | else:
443 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
444 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
445 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
446 |
447 | for i in range(5):
448 | if i == 0:
449 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
450 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
451 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
452 | else:
453 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
454 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
455 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
456 |
457 |         # apply the task-specific prediction heads to the final shared feature map
458 |         t1_pred = F.log_softmax(self.pred_task1(g_decoder[-1][1]), dim=1)
459 |         t2_pred = self.pred_task2(g_decoder[-1][1])
460 |         t3_pred = self.pred_task3(g_decoder[-1][1])
461 |         t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)
462 | 
463 |         # NOTE: the second return value is the shared representation
464 |         return [t1_pred, t2_pred, t3_pred], g_decoder[-1][1]
466 |
467 |
468 | class SegNet(nn.Module):
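    """Wrapper around SegNetSplit with the same shared / task-specific
    parameter-group interface as SegNetMtan."""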
469 | def __init__(self):
470 | super().__init__()
471 | self.segnet = SegNetSplit()
472 |
473 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
474 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n)
475 |
476 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]:
477 | return (p for n, p in self.segnet.named_parameters() if "pred" in n)
478 |
479 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]:
480 |         """Parameters of the last shared layer: the final shared decoder
481 |         conv block (conv_block_dec[-5]), which is applied last in the
482 |         decoder pass, just before the task-specific heads.
483 |         """
484 |         return self.segnet.conv_block_dec[-5].parameters()
485 |
486 | def forward(self, x, return_representation=False):
487 | if return_representation:
488 | return self.segnet(x)
489 | else:
490 | pred, rep = self.segnet(x)
491 | return pred
492 |
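
# Minimal smoke-test sketch (illustrative only; assumes NYUv2-sized 3x288x384
# inputs, though any spatial size divisible by 32 works since the encoder pools
# five times). Relies on the module-level torch import used throughout this file.
if __name__ == "__main__":
    x = torch.randn(2, 3, 288, 384)
    for model in (SegNet(), SegNetMtan()):
        # forward() returns [segmentation, depth, normals] predictions
        seg, depth, normals = model(x)
        print(type(model).__name__, seg.shape, depth.shape, normals.shape)
        # return_representation=True also yields the shared / per-task features
        preds, reps = model(x, return_representation=True)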
--------------------------------------------------------------------------------